Skip to content

Commit c0622a5

Browse files
authored
fix(typing): improve overloads to ensure the return type follows the default_value type (aws-powertools#4114)
1 parent 32e733b commit c0622a5

File tree

8 files changed

+115
-10
lines changed

8 files changed

+115
-10
lines changed

Diff for: aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py

+11
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,17 @@ def path_parameters(self) -> Optional[Dict[str, str]]:
283283
def stage_variables(self) -> Optional[Dict[str, str]]:
284284
return self.get("stageVariables")
285285

286+
@overload
287+
def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ...
288+
289+
@overload
290+
def get_header_value(
291+
self,
292+
name: str,
293+
default_value: Optional[str] = None,
294+
case_sensitive: Optional[bool] = False,
295+
) -> Optional[str]: ...
296+
286297
def get_header_value(
287298
self,
288299
name: str,

Diff for: aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional, Union
1+
from typing import Any, Dict, List, Optional, Union, overload
22

33
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
44
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
@@ -214,6 +214,22 @@ def stash(self) -> Optional[dict]:
214214
a pipeline resolver."""
215215
return self.get("stash")
216216

217+
@overload
218+
def get_header_value(
219+
self,
220+
name: str,
221+
default_value: str,
222+
case_sensitive: Optional[bool] = False,
223+
) -> str: ...
224+
225+
@overload
226+
def get_header_value(
227+
self,
228+
name: str,
229+
default_value: Optional[str] = None,
230+
case_sensitive: Optional[bool] = False,
231+
) -> Optional[str]: ...
232+
217233
def get_header_value(
218234
self,
219235
name: str,

Diff for: aws_lambda_powertools/utilities/data_classes/common.py

+6
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ def http_method(self) -> str:
172172
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
173173
return self["httpMethod"]
174174

175+
@overload
176+
def get_query_string_value(self, name: str, default_value: str) -> str: ...
177+
178+
@overload
179+
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ...
180+
175181
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]:
176182
"""Get query string value by name
177183

Diff for: aws_lambda_powertools/utilities/data_classes/kafka_event.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import base64
22
from functools import cached_property
3-
from typing import Any, Dict, Iterator, List, Optional
3+
from typing import Any, Dict, Iterator, List, Optional, overload
44

55
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
66
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
@@ -69,10 +69,26 @@ def decoded_headers(self) -> Dict[str, bytes]:
6969
"""Decodes the headers as a single dictionary."""
7070
return {k: bytes(v) for chunk in self.headers for k, v in chunk.items()}
7171

72+
@overload
7273
def get_header_value(
7374
self,
7475
name: str,
75-
default_value: Optional[Any] = None,
76+
default_value: str,
77+
case_sensitive: bool = True,
78+
) -> str: ...
79+
80+
@overload
81+
def get_header_value(
82+
self,
83+
name: str,
84+
default_value: Optional[str] = None,
85+
case_sensitive: bool = True,
86+
) -> Optional[str]: ...
87+
88+
def get_header_value(
89+
self,
90+
name: str,
91+
default_value: Optional[str] = None,
7692
case_sensitive: bool = True,
7793
) -> Optional[str]:
7894
"""Get a decoded header value by name."""

Diff for: aws_lambda_powertools/utilities/data_classes/s3_object_event.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional
1+
from typing import Dict, Optional, overload
22

33
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
44
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
@@ -73,6 +73,22 @@ def headers(self) -> Dict[str, str]:
7373
The case of the original headers is retained in this map."""
7474
return self["headers"]
7575

76+
@overload
77+
def get_header_value(
78+
self,
79+
name: str,
80+
default_value: str,
81+
case_sensitive: Optional[bool] = False,
82+
) -> str: ...
83+
84+
@overload
85+
def get_header_value(
86+
self,
87+
name: str,
88+
default_value: Optional[str] = None,
89+
case_sensitive: Optional[bool] = False,
90+
) -> Optional[str]: ...
91+
7692
def get_header_value(
7793
self,
7894
name: str,

Diff for: aws_lambda_powertools/utilities/data_classes/shared_functions.py

+38-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import base64
4-
from typing import Any, Dict
4+
from typing import Any, Dict, overload
55

66

77
def base64_decode(value: str) -> str:
@@ -21,11 +21,29 @@ def base64_decode(value: str) -> str:
2121
return base64.b64decode(value).decode("UTF-8")
2222

2323

24+
@overload
2425
def get_header_value(
2526
headers: dict[str, Any],
2627
name: str,
27-
default_value: str | None,
28-
case_sensitive: bool | None,
28+
default_value: str,
29+
case_sensitive: bool | None = False,
30+
) -> str: ...
31+
32+
33+
@overload
34+
def get_header_value(
35+
headers: dict[str, Any],
36+
name: str,
37+
default_value: str | None = None,
38+
case_sensitive: bool | None = False,
39+
) -> str | None: ...
40+
41+
42+
def get_header_value(
43+
headers: dict[str, Any],
44+
name: str,
45+
default_value: str | None = None,
46+
case_sensitive: bool | None = False,
2947
) -> str | None:
3048
"""
3149
Get the value of a header by its name.
@@ -39,7 +57,7 @@ def get_header_value(
3957
default_value: str, optional
4058
The default value to return if the header is not found. Default is None.
4159
case_sensitive: bool, optional
42-
Indicates whether the header name should be case-sensitive. Default is None.
60+
Indicates whether the header name should be case-sensitive. Default is False.
4361
4462
Returns
4563
-------
@@ -62,6 +80,22 @@ def get_header_value(
6280
)
6381

6482

83+
@overload
84+
def get_query_string_value(
85+
query_string_parameters: Dict[str, str] | None,
86+
name: str,
87+
default_value: str,
88+
) -> str: ...
89+
90+
91+
@overload
92+
def get_query_string_value(
93+
query_string_parameters: Dict[str, str] | None,
94+
name: str,
95+
default_value: str | None = None,
96+
) -> str | None: ...
97+
98+
6599
def get_query_string_value(
66100
query_string_parameters: Dict[str, str] | None,
67101
name: str,

Diff for: aws_lambda_powertools/utilities/data_classes/vpc_lattice.py

+6
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ def http_method(self) -> str:
4747
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
4848
return self["method"]
4949

50+
@overload
51+
def get_query_string_value(self, name: str, default_value: str) -> str: ...
52+
53+
@overload
54+
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ...
55+
5056
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]:
5157
"""Get query string value by name
5258

Diff for: examples/event_handler_graphql/src/custom_models.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ class Location(TypedDict, total=False):
2626
class MyCustomModel(AppSyncResolverEvent):
2727
@property
2828
def country_viewer(self) -> str:
29-
return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False) # type: ignore[return-value] # sentinel typing # noqa: E501
29+
return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False)
3030

3131
@property
3232
def api_key(self) -> str:
33-
return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False) # type: ignore[return-value] # sentinel typing # noqa: E501
33+
return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False)
3434

3535

3636
@app.resolver(type_name="Query", field_name="listLocations")

0 commit comments

Comments
 (0)