15
15
Any ,
16
16
Callable ,
17
17
Dict ,
18
+ Generic ,
18
19
List ,
19
20
Match ,
20
21
Optional ,
23
24
Set ,
24
25
Tuple ,
25
26
Type ,
27
+ TypeVar ,
26
28
Union ,
27
29
cast ,
28
30
)
45
47
ALBEvent ,
46
48
APIGatewayProxyEvent ,
47
49
APIGatewayProxyEventV2 ,
50
+ BedrockAgentEvent ,
48
51
LambdaFunctionUrlEvent ,
49
52
VPCLatticeEvent ,
50
53
VPCLatticeEventV2 ,
62
65
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response"
63
66
_ROUTE_REGEX = "^{}$"
64
67
68
+ ResponseEventT = TypeVar ("ResponseEventT" , bound = BaseProxyEvent )
69
+
65
70
if TYPE_CHECKING :
66
71
from aws_lambda_powertools .event_handler .openapi .compat import (
67
72
JsonSchemaValue ,
@@ -85,6 +90,7 @@ class ProxyEventType(Enum):
85
90
APIGatewayProxyEvent = "APIGatewayProxyEvent"
86
91
APIGatewayProxyEventV2 = "APIGatewayProxyEventV2"
87
92
ALBEvent = "ALBEvent"
93
+ BedrockAgentEvent = "BedrockAgentEvent"
88
94
VPCLatticeEvent = "VPCLatticeEvent"
89
95
VPCLatticeEventV2 = "VPCLatticeEventV2"
90
96
LambdaFunctionUrlEvent = "LambdaFunctionUrlEvent"
@@ -208,7 +214,7 @@ def __init__(
208
214
self ,
209
215
status_code : int ,
210
216
content_type : Optional [str ] = None ,
211
- body : Union [ str , bytes , None ] = None ,
217
+ body : Any = None ,
212
218
headers : Optional [Dict [str , Union [str , List [str ]]]] = None ,
213
219
cookies : Optional [List [Cookie ]] = None ,
214
220
compress : Optional [bool ] = None ,
@@ -235,6 +241,7 @@ def __init__(
235
241
self .headers : Dict [str , Union [str , List [str ]]] = headers if headers else {}
236
242
self .cookies = cookies or []
237
243
self .compress = compress
244
+ self .content_type = content_type
238
245
if content_type :
239
246
self .headers .setdefault ("Content-Type" , content_type )
240
247
@@ -689,14 +696,14 @@ def _generate_operation_id(self) -> str:
689
696
return operation_id
690
697
691
698
692
- class ResponseBuilder :
699
+ class ResponseBuilder ( Generic [ ResponseEventT ]) :
693
700
"""Internally used Response builder"""
694
701
695
702
def __init__ (self , response : Response , route : Optional [Route ] = None ):
696
703
self .response = response
697
704
self .route = route
698
705
699
- def _add_cors (self , event : BaseProxyEvent , cors : CORSConfig ):
706
+ def _add_cors (self , event : ResponseEventT , cors : CORSConfig ):
700
707
"""Update headers to include the configured Access-Control headers"""
701
708
self .response .headers .update (cors .to_dict (event .get_header_value ("Origin" )))
702
709
@@ -709,7 +716,7 @@ def _add_cache_control(self, cache_control: str):
709
716
def _has_compression_enabled (
710
717
route_compression : bool ,
711
718
response_compression : Optional [bool ],
712
- event : BaseProxyEvent ,
719
+ event : ResponseEventT ,
713
720
) -> bool :
714
721
"""
715
722
Checks if compression is enabled.
@@ -722,7 +729,7 @@ def _has_compression_enabled(
722
729
A boolean indicating whether compression is enabled or not in the route setting.
723
730
response_compression: bool, optional
724
731
A boolean indicating whether compression is enabled or not in the response setting.
725
- event: BaseProxyEvent
732
+ event: ResponseEventT
726
733
The event object containing the request details.
727
734
728
735
Returns
@@ -752,7 +759,7 @@ def _compress(self):
752
759
gzip = zlib .compressobj (9 , zlib .DEFLATED , zlib .MAX_WBITS | 16 )
753
760
self .response .body = gzip .compress (self .response .body ) + gzip .flush ()
754
761
755
- def _route (self , event : BaseProxyEvent , cors : Optional [CORSConfig ]):
762
+ def _route (self , event : ResponseEventT , cors : Optional [CORSConfig ]):
756
763
"""Optionally handle any of the route's configure response handling"""
757
764
if self .route is None :
758
765
return
@@ -767,7 +774,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
767
774
):
768
775
self ._compress ()
769
776
770
- def build (self , event : BaseProxyEvent , cors : Optional [CORSConfig ] = None ) -> Dict [str , Any ]:
777
+ def build (self , event : ResponseEventT , cors : Optional [CORSConfig ] = None ) -> Dict [str , Any ]:
771
778
"""Build the full response dict to be returned by the lambda"""
772
779
self ._route (event , cors )
773
780
@@ -1315,6 +1322,7 @@ def __init__(
1315
1322
self ._strip_prefixes = strip_prefixes
1316
1323
self .context : Dict = {} # early init as customers might add context before event resolution
1317
1324
self .processed_stack_frames = []
1325
+ self ._response_builder_class = ResponseBuilder [BaseProxyEvent ]
1318
1326
1319
1327
# Allow for a custom serializer or a concise json serialization
1320
1328
self ._serializer = serializer or partial (json .dumps , separators = ("," , ":" ), cls = Encoder )
@@ -1784,14 +1792,17 @@ def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX):
1784
1792
rule_regex : str = re .sub (_DYNAMIC_ROUTE_PATTERN , _NAMED_GROUP_BOUNDARY_PATTERN , rule )
1785
1793
return re .compile (base_regex .format (rule_regex ))
1786
1794
1787
- def _to_proxy_event (self , event : Dict ) -> BaseProxyEvent :
1795
+ def _to_proxy_event (self , event : Dict ) -> BaseProxyEvent : # noqa: PLR0911 # ignore many returns
1788
1796
"""Convert the event dict to the corresponding data class"""
1789
1797
if self ._proxy_type == ProxyEventType .APIGatewayProxyEvent :
1790
1798
logger .debug ("Converting event to API Gateway REST API contract" )
1791
1799
return APIGatewayProxyEvent (event )
1792
1800
if self ._proxy_type == ProxyEventType .APIGatewayProxyEventV2 :
1793
1801
logger .debug ("Converting event to API Gateway HTTP API contract" )
1794
1802
return APIGatewayProxyEventV2 (event )
1803
+ if self ._proxy_type == ProxyEventType .BedrockAgentEvent :
1804
+ logger .debug ("Converting event to Bedrock Agent contract" )
1805
+ return BedrockAgentEvent (event )
1795
1806
if self ._proxy_type == ProxyEventType .LambdaFunctionUrlEvent :
1796
1807
logger .debug ("Converting event to Lambda Function URL contract" )
1797
1808
return LambdaFunctionUrlEvent (event )
@@ -1869,9 +1880,9 @@ def _not_found(self, method: str) -> ResponseBuilder:
1869
1880
1870
1881
handler = self ._lookup_exception_handler (NotFoundError )
1871
1882
if handler :
1872
- return ResponseBuilder (handler (NotFoundError ()))
1883
+ return self . _response_builder_class (handler (NotFoundError ()))
1873
1884
1874
- return ResponseBuilder (
1885
+ return self . _response_builder_class (
1875
1886
Response (
1876
1887
status_code = HTTPStatus .NOT_FOUND .value ,
1877
1888
content_type = content_types .APPLICATION_JSON ,
@@ -1886,7 +1897,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
1886
1897
# Reset Processed stack for Middleware (for debugging purposes)
1887
1898
self ._reset_processed_stack ()
1888
1899
1889
- return ResponseBuilder (
1900
+ return self . _response_builder_class (
1890
1901
self ._to_response (
1891
1902
route (router_middlewares = self ._router_middlewares , app = self , route_arguments = route_arguments ),
1892
1903
),
@@ -1903,7 +1914,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
1903
1914
# If the user has turned on debug mode,
1904
1915
# we'll let the original exception propagate, so
1905
1916
# they get more information about what went wrong.
1906
- return ResponseBuilder (
1917
+ return self . _response_builder_class (
1907
1918
Response (
1908
1919
status_code = 500 ,
1909
1920
content_type = content_types .TEXT_PLAIN ,
@@ -1942,12 +1953,12 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
1942
1953
handler = self ._lookup_exception_handler (type (exp ))
1943
1954
if handler :
1944
1955
try :
1945
- return ResponseBuilder (handler (exp ), route )
1956
+ return self . _response_builder_class (handler (exp ), route )
1946
1957
except ServiceError as service_error :
1947
1958
exp = service_error
1948
1959
1949
1960
if isinstance (exp , ServiceError ):
1950
- return ResponseBuilder (
1961
+ return self . _response_builder_class (
1951
1962
Response (
1952
1963
status_code = exp .status_code ,
1953
1964
content_type = content_types .APPLICATION_JSON ,
0 commit comments