From 0259154577889946a330b7a040f61379e56a5d06 Mon Sep 17 00:00:00 2001 From: Eric Nielsen <4120606+ericbn@users.noreply.github.com> Date: Mon, 17 Mar 2025 19:45:21 -0500 Subject: [PATCH 1/2] feat(data_classes): add API Gateway Websocket event Pydantic models were added for API Gateway WebSocket events here: https://docs.powertools.aws.dev/lambda/python/latest/utilities/parser/#built-in-models Add the corresponding data model class. --- .../utilities/data_classes/__init__.py | 2 + .../api_gateway_websocket_event.py | 128 ++++++++++++++++++ .../test_api_gateway_websocket_event.py | 95 +++++++++++++ 3 files changed, 225 insertions(+) create mode 100644 aws_lambda_powertools/utilities/data_classes/api_gateway_websocket_event.py create mode 100644 tests/unit/data_classes/required_dependencies/test_api_gateway_websocket_event.py diff --git a/aws_lambda_powertools/utilities/data_classes/__init__.py b/aws_lambda_powertools/utilities/data_classes/__init__.py index 8d20de7d192..2757725dc62 100644 --- a/aws_lambda_powertools/utilities/data_classes/__init__.py +++ b/aws_lambda_powertools/utilities/data_classes/__init__.py @@ -4,6 +4,7 @@ from .alb_event import ALBEvent from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2 +from .api_gateway_websocket_event import APIGatewayWebSocketEvent from .appsync_resolver_event import AppSyncResolverEvent from .aws_config_rule_event import AWSConfigRuleEvent from .bedrock_agent_event import BedrockAgentEvent @@ -51,6 +52,7 @@ __all__ = [ "APIGatewayProxyEvent", "APIGatewayProxyEventV2", + "APIGatewayWebSocketEvent", "SecretsManagerEvent", "AppSyncResolverEvent", "ALBEvent", diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_websocket_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_websocket_event.py new file mode 100644 index 00000000000..f71e236f874 --- /dev/null +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_websocket_event.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import base64 +from functools import cached_property +from typing import Any + +from aws_lambda_powertools.utilities.data_classes.common import ( + CaseInsensitiveDict, + DictWrapper, +) + + +class APIGatewayWebSocketEventIdentity(DictWrapper): + @property + def source_ip(self) -> str: + return self["sourceIp"] + + @property + def user_agent(self) -> str | None: + return self.get("userAgent") + + +class APIGatewayWebSocketEventRequestContext(DictWrapper): + @property + def route_key(self) -> str: + return self["routeKey"] + + @property + def disconnect_status_code(self) -> int | None: + return self.get("disconnectStatusCode") + + @property + def message_id(self) -> str | None: + return self.get("messageId") + + @property + def event_type(self) -> str: + return self["eventType"] + + @property + def extended_request_id(self) -> str: + return self["extendedRequestId"] + + @property + def request_time(self) -> str: + return self["requestTime"] + + @property + def message_direction(self) -> str: + return self["messageDirection"] + + @property + def disconnect_reason(self) -> str | None: + return self.get("disconnectReason") + + @property + def stage(self) -> str: + return self["stage"] + + @property + def connected_at(self) -> int: + return self["connectedAt"] + + @property + def request_time_epoch(self) -> int: + return self["requestTimeEpoch"] + + @property + def identity(self) -> APIGatewayWebSocketEventIdentity: + return APIGatewayWebSocketEventIdentity(self["identity"]) + + @property + def request_id(self) -> str: + return self["requestId"] + + @property + def domain_name(self) -> str: + return self["domainName"] + + @property + def connection_id(self) -> str: + return self["connectionId"] + + @property + def api_id(self) -> str: + return self["apiId"] + + +class APIGatewayWebSocketEvent(DictWrapper): + """AWS proxy integration event for WebSocket API + + Documentation: + -------------- + - https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-websocket-api-integration-requests.html + """ + + @property + def is_base64_encoded(self) -> bool: + return self["isBase64Encoded"] + + @property + def body(self) -> str | None: + return self.get("body") + + @cached_property + def decoded_body(self) -> str | None: + body = self.body + if self.is_base64_encoded and body: + return base64.b64decode(body.encode()).decode() + return body + + @cached_property + def json_body(self) -> Any: + if self.decoded_body: + return self._json_deserializer(self.decoded_body) + return None + + @property + def headers(self) -> dict[str, str]: + return CaseInsensitiveDict(self.get("headers")) + + @property + def multi_value_headers(self) -> dict[str, list[str]]: + return CaseInsensitiveDict(self.get("multiValueHeaders")) + + @property + def request_context(self) -> APIGatewayWebSocketEventRequestContext: + return APIGatewayWebSocketEventRequestContext(self["requestContext"]) diff --git a/tests/unit/data_classes/required_dependencies/test_api_gateway_websocket_event.py b/tests/unit/data_classes/required_dependencies/test_api_gateway_websocket_event.py new file mode 100644 index 00000000000..4381cd7cefb --- /dev/null +++ b/tests/unit/data_classes/required_dependencies/test_api_gateway_websocket_event.py @@ -0,0 +1,95 @@ +from aws_lambda_powertools.utilities.data_classes import APIGatewayWebSocketEvent +from tests.functional.utils import load_event + + +def test_connect_api_gateway_websocket_event(): + raw_event = load_event("apiGatewayWebSocketApiConnect.json") + parsed_event = APIGatewayWebSocketEvent(raw_event) + + assert parsed_event.body is None + assert parsed_event.headers == raw_event["headers"] + assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"] + + request_context = parsed_event.request_context + request_context_raw = raw_event["requestContext"] + assert request_context.route_key == request_context_raw["routeKey"] + assert request_context.disconnect_status_code is None + assert request_context.message_id is None + assert request_context.event_type == request_context_raw["eventType"] + assert request_context.extended_request_id == request_context_raw["extendedRequestId"] + assert request_context.request_time == request_context_raw["requestTime"] + assert request_context.message_direction == request_context_raw["messageDirection"] + assert request_context.disconnect_reason is None + assert request_context.stage == request_context_raw["stage"] + assert request_context.connected_at == request_context_raw["connectedAt"] + assert request_context.request_time_epoch == request_context_raw["requestTimeEpoch"] + assert request_context.request_id == request_context_raw["requestId"] + assert request_context.domain_name == request_context_raw["domainName"] + assert request_context.connection_id == request_context_raw["connectionId"] + assert request_context.api_id == request_context_raw["apiId"] + + identity = request_context.identity + identity_raw = request_context_raw["identity"] + assert identity.source_ip == identity_raw["sourceIp"] + + +def test_disconnect_api_gateway_websocket_event(): + raw_event = load_event("apiGatewayWebSocketApiDisconnect.json") + parsed_event = APIGatewayWebSocketEvent(raw_event) + + assert parsed_event.body is None + assert parsed_event.headers == raw_event["headers"] + assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"] + + request_context = parsed_event.request_context + request_context_raw = raw_event["requestContext"] + assert request_context.route_key == request_context_raw["routeKey"] + assert request_context.disconnect_status_code == request_context_raw["disconnectStatusCode"] + assert request_context.message_id is None + assert request_context.event_type == request_context_raw["eventType"] + assert request_context.extended_request_id == request_context_raw["extendedRequestId"] + assert request_context.request_time == request_context_raw["requestTime"] + assert request_context.message_direction == request_context_raw["messageDirection"] + assert request_context.disconnect_reason == request_context_raw["disconnectReason"] + assert request_context.stage == request_context_raw["stage"] + assert request_context.connected_at == request_context_raw["connectedAt"] + assert request_context.request_time_epoch == request_context_raw["requestTimeEpoch"] + assert request_context.request_id == request_context_raw["requestId"] + assert request_context.domain_name == request_context_raw["domainName"] + assert request_context.connection_id == request_context_raw["connectionId"] + assert request_context.api_id == request_context_raw["apiId"] + + identity = request_context.identity + identity_raw = request_context_raw["identity"] + assert identity.source_ip == identity_raw["sourceIp"] + + +def test_message_api_gateway_websocket_event(): + raw_event = load_event("apiGatewayWebSocketApiMessage.json") + parsed_event = APIGatewayWebSocketEvent(raw_event) + + assert parsed_event.body == raw_event["body"] + assert parsed_event.headers == {} + assert parsed_event.multi_value_headers == {} + + request_context = parsed_event.request_context + request_context_raw = raw_event["requestContext"] + assert request_context.route_key == request_context_raw["routeKey"] + assert request_context.disconnect_status_code is None + assert request_context.message_id == request_context_raw["messageId"] + assert request_context.event_type == request_context_raw["eventType"] + assert request_context.extended_request_id == request_context_raw["extendedRequestId"] + assert request_context.request_time == request_context_raw["requestTime"] + assert request_context.message_direction == request_context_raw["messageDirection"] + assert request_context.disconnect_reason is None + assert request_context.stage == request_context_raw["stage"] + assert request_context.connected_at == request_context_raw["connectedAt"] + assert request_context.request_time_epoch == request_context_raw["requestTimeEpoch"] + assert request_context.request_id == request_context_raw["requestId"] + assert request_context.domain_name == request_context_raw["domainName"] + assert request_context.connection_id == request_context_raw["connectionId"] + assert request_context.api_id == request_context_raw["apiId"] + + identity = request_context.identity + identity_raw = request_context_raw["identity"] + assert identity.source_ip == identity_raw["sourceIp"] From 0153eadfbbcf03c639cba435da0ff35ff29675eb Mon Sep 17 00:00:00 2001 From: Eric Nielsen <4120606+ericbn@users.noreply.github.com> Date: Tue, 18 Mar 2025 08:20:23 -0500 Subject: [PATCH 2/2] feat(data_classes): increase tests code coverage --- .../test_api_gateway_websocket_event.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/unit/data_classes/required_dependencies/test_api_gateway_websocket_event.py b/tests/unit/data_classes/required_dependencies/test_api_gateway_websocket_event.py index 4381cd7cefb..faee5b17289 100644 --- a/tests/unit/data_classes/required_dependencies/test_api_gateway_websocket_event.py +++ b/tests/unit/data_classes/required_dependencies/test_api_gateway_websocket_event.py @@ -1,3 +1,5 @@ +import json + from aws_lambda_powertools.utilities.data_classes import APIGatewayWebSocketEvent from tests.functional.utils import load_event @@ -6,7 +8,10 @@ def test_connect_api_gateway_websocket_event(): raw_event = load_event("apiGatewayWebSocketApiConnect.json") parsed_event = APIGatewayWebSocketEvent(raw_event) + assert parsed_event.is_base64_encoded is False assert parsed_event.body is None + assert parsed_event.decoded_body is None + assert parsed_event.json_body is None assert parsed_event.headers == raw_event["headers"] assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"] @@ -31,13 +36,17 @@ def test_connect_api_gateway_websocket_event(): identity = request_context.identity identity_raw = request_context_raw["identity"] assert identity.source_ip == identity_raw["sourceIp"] + assert identity.user_agent is None def test_disconnect_api_gateway_websocket_event(): raw_event = load_event("apiGatewayWebSocketApiDisconnect.json") parsed_event = APIGatewayWebSocketEvent(raw_event) + assert parsed_event.is_base64_encoded is False assert parsed_event.body is None + assert parsed_event.decoded_body is None + assert parsed_event.json_body is None assert parsed_event.headers == raw_event["headers"] assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"] @@ -62,13 +71,17 @@ def test_disconnect_api_gateway_websocket_event(): identity = request_context.identity identity_raw = request_context_raw["identity"] assert identity.source_ip == identity_raw["sourceIp"] + assert identity.user_agent is None def test_message_api_gateway_websocket_event(): raw_event = load_event("apiGatewayWebSocketApiMessage.json") parsed_event = APIGatewayWebSocketEvent(raw_event) + assert parsed_event.is_base64_encoded is False assert parsed_event.body == raw_event["body"] + assert parsed_event.decoded_body == raw_event["body"] + assert parsed_event.json_body == json.loads(raw_event["body"]) assert parsed_event.headers == {} assert parsed_event.multi_value_headers == {} @@ -93,3 +106,4 @@ def test_message_api_gateway_websocket_event(): identity = request_context.identity identity_raw = request_context_raw["identity"] assert identity.source_ip == identity_raw["sourceIp"] + assert identity.user_agent is None