Skip to content

Commit a7f4aa3

Browse files
committed
fix: use generics
1 parent 021f6c6 commit a7f4aa3

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Any,
1616
Callable,
1717
Dict,
18+
Generic,
1819
List,
1920
Match,
2021
Optional,
@@ -23,6 +24,7 @@
2324
Set,
2425
Tuple,
2526
Type,
27+
TypeVar,
2628
Union,
2729
cast,
2830
)
@@ -63,6 +65,8 @@
6365
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response"
6466
_ROUTE_REGEX = "^{}$"
6567

68+
ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)
69+
6670
if TYPE_CHECKING:
6771
from aws_lambda_powertools.event_handler.openapi.compat import (
6872
JsonSchemaValue,
@@ -691,14 +695,14 @@ def _generate_operation_id(self) -> str:
691695
return operation_id
692696

693697

694-
class ResponseBuilder:
698+
class ResponseBuilder(Generic[ResponseEventT]):
695699
"""Internally used Response builder"""
696700

697701
def __init__(self, response: Response, route: Optional[Route] = None):
698702
self.response = response
699703
self.route = route
700704

701-
def _add_cors(self, event: BaseProxyEvent, cors: CORSConfig):
705+
def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
702706
"""Update headers to include the configured Access-Control headers"""
703707
self.response.headers.update(cors.to_dict(event.get_header_value("Origin")))
704708

@@ -711,7 +715,7 @@ def _add_cache_control(self, cache_control: str):
711715
def _has_compression_enabled(
712716
route_compression: bool,
713717
response_compression: Optional[bool],
714-
event: BaseProxyEvent,
718+
event: ResponseEventT,
715719
) -> bool:
716720
"""
717721
Checks if compression is enabled.
@@ -724,7 +728,7 @@ def _has_compression_enabled(
724728
A boolean indicating whether compression is enabled or not in the route setting.
725729
response_compression: bool, optional
726730
A boolean indicating whether compression is enabled or not in the response setting.
727-
event: BaseProxyEvent
731+
event: Generic[ResponseEventT]
728732
The event object containing the request details.
729733
730734
Returns
@@ -754,7 +758,7 @@ def _compress(self):
754758
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
755759
self.response.body = gzip.compress(self.response.body) + gzip.flush()
756760

757-
def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
761+
def _route(self, event: ResponseEventT, cors: Optional[CORSConfig]):
758762
"""Optionally handle any of the route's configure response handling"""
759763
if self.route is None:
760764
return
@@ -769,7 +773,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
769773
):
770774
self._compress()
771775

772-
def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
776+
def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
773777
"""Build the full response dict to be returned by the lambda"""
774778
self._route(event, cors)
775779

@@ -1317,7 +1321,7 @@ def __init__(
13171321
self._strip_prefixes = strip_prefixes
13181322
self.context: Dict = {} # early init as customers might add context before event resolution
13191323
self.processed_stack_frames = []
1320-
self._response_builder_class = ResponseBuilder
1324+
self._response_builder_class = ResponseBuilder[BaseProxyEvent]
13211325

13221326
# Allow for a custom serializer or a concise json serialization
13231327
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)

aws_lambda_powertools/event_handler/bedrock_agent.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
11
import logging
2-
from typing import Any, Dict, Optional, cast
2+
from typing import Any, Dict, Optional
33

44
from typing_extensions import override
55

66
from aws_lambda_powertools.event_handler import ApiGatewayResolver
77
from aws_lambda_powertools.event_handler.api_gateway import CORSConfig, ProxyEventType, ResponseBuilder
88
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
9-
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent
109

1110
logger = logging.getLogger(__name__)
1211

1312

1413
class BedrockResponseBuilder(ResponseBuilder):
1514
@override
16-
def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
15+
def build(self, event: BedrockAgentEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
1716
"""Build the full response dict to be returned by the lambda"""
1817
self._route(event, cors)
1918

20-
bedrock_event = cast(BedrockAgentEvent, event)
21-
2219
return {
2320
"messageVersion": "1.0",
2421
"response": {
25-
"actionGroup": bedrock_event.action_group,
26-
"apiPath": bedrock_event.api_path,
27-
"httpMethod": bedrock_event.http_method,
22+
"actionGroup": event.action_group,
23+
"apiPath": event.api_path,
24+
"httpMethod": event.http_method,
2825
"httpStatusCode": self.response.status_code,
2926
"responseBody": {
3027
"application/json": {

0 commit comments

Comments
 (0)