Skip to content

Commit 05a2439

Browse files
Addressing feedback
1 parent c2aa872 commit 05a2439

File tree

2 files changed

+29
-25
lines changed

2 files changed

+29
-25
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+3-25
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
validation_error_definition,
4444
validation_error_response_definition,
4545
)
46-
from aws_lambda_powertools.event_handler.util import _FrozenDict
46+
from aws_lambda_powertools.event_handler.util import _FrozenDict, extract_origin_header
4747
from aws_lambda_powertools.shared.cookies import Cookie
4848
from aws_lambda_powertools.shared.functions import powertools_dev_is_set
4949
from aws_lambda_powertools.shared.json_encoder import Encoder
@@ -58,7 +58,6 @@
5858
VPCLatticeEventV2,
5959
)
6060
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent
61-
from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value
6261
from aws_lambda_powertools.utilities.typing import LambdaContext
6362

6463
logger = logging.getLogger(__name__)
@@ -218,27 +217,6 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
218217
headers["Access-Control-Allow-Credentials"] = "true"
219218
return headers
220219

221-
@staticmethod
222-
def extract_origin_header(resolver_headers: Dict):
223-
"""
224-
Extracts the 'origin' or 'Origin' header from the provided resolver headers.
225-
226-
The 'origin' or 'Origin' header can be either a single header or a multi-header.
227-
228-
Args:
229-
resolver_headers (Dict): A dictionary containing the headers.
230-
231-
Returns:
232-
Optional[str]: The value(s) of the origin header or None.
233-
"""
234-
resolved_header = get_header_value(resolver_headers, "origin", None, case_sensitive=False)
235-
if isinstance(resolved_header, str):
236-
return resolved_header
237-
if isinstance(resolved_header, list):
238-
return resolved_header[0]
239-
240-
return resolved_header
241-
242220

243221
class Response(Generic[ResponseT]):
244222
"""Response data class that provides greater control over what is returned from the proxy event"""
@@ -804,7 +782,7 @@ def __init__(
804782

805783
def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
806784
"""Update headers to include the configured Access-Control headers"""
807-
extracted_origin_header = cors.extract_origin_header(event.resolved_headers_field)
785+
extracted_origin_header = extract_origin_header(event.resolved_headers_field)
808786
self.response.headers.update(cors.to_dict(extracted_origin_header))
809787

810788
def _add_cache_control(self, cache_control: str):
@@ -2152,7 +2130,7 @@ def _not_found(self, method: str) -> ResponseBuilder:
21522130
headers = {}
21532131
if self._cors:
21542132
logger.debug("CORS is enabled, updating headers.")
2155-
extracted_origin_header = self._cors.extract_origin_header(self.current_event.resolved_headers_field)
2133+
extracted_origin_header = extract_origin_header(self.current_event.resolved_headers_field)
21562134
headers.update(self._cors.to_dict(extracted_origin_header))
21572135

21582136
if method == "OPTIONS":

aws_lambda_powertools/event_handler/util.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
from typing import Any, Dict
2+
3+
from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value
4+
5+
16
class _FrozenDict(dict):
27
"""
38
A dictionary that can be used as a key in another dictionary.
@@ -11,3 +16,24 @@ class _FrozenDict(dict):
1116

1217
def __hash__(self):
1318
return hash(frozenset(self.keys()))
19+
20+
21+
def extract_origin_header(resolver_headers: Dict[str, Any]):
22+
"""
23+
Extracts the 'origin' or 'Origin' header from the provided resolver headers.
24+
25+
The 'origin' or 'Origin' header can be either a single header or a multi-header.
26+
27+
Args:
28+
resolver_headers (Dict): A dictionary containing the headers.
29+
30+
Returns:
31+
Optional[str]: The value(s) of the origin header or None.
32+
"""
33+
resolved_header = get_header_value(resolver_headers, "origin", None, case_sensitive=False)
34+
if isinstance(resolved_header, str):
35+
return resolved_header
36+
if isinstance(resolved_header, list):
37+
return resolved_header[0]
38+
39+
return resolved_header

0 commit comments

Comments
 (0)