Skip to content

refactor(event_handler): add from __future__ import annotations in the Middlewares #4975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions aws_lambda_powertools/event_handler/middlewares/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Generic, Protocol
from typing import TYPE_CHECKING, Generic, Protocol

from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.types import EventHandlerInstance

if TYPE_CHECKING:
from aws_lambda_powertools.event_handler.api_gateway import Response


class NextMiddleware(Protocol):
def __call__(self, app: EventHandlerInstance) -> Response:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import dataclasses
import json
import logging
from copy import deepcopy
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence

from pydantic import BaseModel

from aws_lambda_powertools.event_handler import Response
from aws_lambda_powertools.event_handler.api_gateway import Route
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
from aws_lambda_powertools.event_handler.openapi.compat import (
ModelField,
Expand All @@ -20,8 +20,12 @@
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
from aws_lambda_powertools.event_handler.openapi.params import Param
from aws_lambda_powertools.event_handler.openapi.types import IncEx
from aws_lambda_powertools.event_handler.types import EventHandlerInstance

if TYPE_CHECKING:
from aws_lambda_powertools.event_handler import Response
from aws_lambda_powertools.event_handler.api_gateway import Route
from aws_lambda_powertools.event_handler.openapi.types import IncEx
from aws_lambda_powertools.event_handler.types import EventHandlerInstance

logger = logging.getLogger(__name__)

Expand All @@ -36,8 +40,6 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
--------

```python
from typing import List

from pydantic import BaseModel

from aws_lambda_powertools.event_handler.api_gateway import (
Expand All @@ -50,12 +52,12 @@ class Todo(BaseModel):
app = APIGatewayRestResolver(enable_validation=True)

@app.get("/todos")
def get_todos(): List[Todo]:
def get_todos(): list[Todo]:
return [Todo(name="hello world")]
```
"""

def __init__(self, validation_serializer: Optional[Callable[[Any], str]] = None):
def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
"""
Initialize the OpenAPIValidationMiddleware.

Expand All @@ -72,8 +74,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->

route: Route = app.context["_route"]

values: Dict[str, Any] = {}
errors: List[Any] = []
values: dict[str, Any] = {}
errors: list[Any] = []

# Process path values, which can be found on the route_args
path_values, path_errors = _request_params_to_args(
Expand Down Expand Up @@ -147,10 +149,10 @@ def _handle_response(self, *, route: Route, response: Response):
def _serialize_response(
self,
*,
field: Optional[ModelField] = None,
field: ModelField | None = None,
response_content: Any,
include: Optional[IncEx] = None,
exclude: Optional[IncEx] = None,
include: IncEx | None = None,
exclude: IncEx | None = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
Expand All @@ -160,7 +162,7 @@ def _serialize_response(
Serialize the response content according to the field type.
"""
if field:
errors: List[Dict[str, Any]] = []
errors: list[dict[str, Any]] = []
# MAINTENANCE: remove this when we drop pydantic v1
if not hasattr(field, "serializable"):
response_content = self._prepare_response_content(
Expand Down Expand Up @@ -232,7 +234,7 @@ def _prepare_response_content(
return dataclasses.asdict(res)
return res

def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
"""
Get the request body from the event, and parse it as JSON.
"""
Expand Down Expand Up @@ -261,7 +263,7 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
def _request_params_to_args(
required_params: Sequence[ModelField],
received_params: Mapping[str, Any],
) -> Tuple[Dict[str, Any], List[Any]]:
) -> tuple[dict[str, Any], list[Any]]:
"""
Convert the request params to a dictionary of values using validation, and returns a list of errors.
"""
Expand Down Expand Up @@ -294,14 +296,14 @@ def _request_params_to_args(


def _request_body_to_args(
required_params: List[ModelField],
received_body: Optional[Dict[str, Any]],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
required_params: list[ModelField],
received_body: dict[str, Any] | None,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
"""
Convert the request body to a dictionary of values using validation, and returns a list of errors.
"""
values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = []
values: dict[str, Any] = {}
errors: list[dict[str, Any]] = []

received_body, field_alias_omitted = _get_embed_body(
field=required_params[0],
Expand All @@ -313,11 +315,11 @@ def _request_body_to_args(
# This sets the location to:
# { "user": { object } } if field.alias == user
# { { object } if field_alias is omitted
loc: Tuple[str, ...] = ("body", field.alias)
loc: tuple[str, ...] = ("body", field.alias)
if field_alias_omitted:
loc = ("body",)

value: Optional[Any] = None
value: Any | None = None

# Now that we know what to look for, try to get the value from the received body
if received_body is not None:
Expand Down Expand Up @@ -347,8 +349,8 @@ def _validate_field(
*,
field: ModelField,
value: Any,
loc: Tuple[str, ...],
existing_errors: List[Dict[str, Any]],
loc: tuple[str, ...],
existing_errors: list[dict[str, Any]],
):
"""
Validate a field, and append any errors to the existing_errors list.
Expand All @@ -367,9 +369,9 @@ def _validate_field(
def _get_embed_body(
*,
field: ModelField,
required_params: List[ModelField],
received_body: Optional[Dict[str, Any]],
) -> Tuple[Optional[Dict[str, Any]], bool]:
required_params: list[ModelField],
received_body: dict[str, Any] | None,
) -> tuple[dict[str, Any] | None, bool]:
field_info = field.field_info
embed = getattr(field_info, "embed", None)

Expand All @@ -382,15 +384,15 @@ def _get_embed_body(


def _normalize_multi_query_string_with_param(
query_string: Dict[str, List[str]],
query_string: dict[str, list[str]],
params: Sequence[ModelField],
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Extract and normalize resolved_query_string_parameters

Parameters
----------
query_string: Dict
query_string: dict
A dictionary containing the initial query string parameters.
params: Sequence[ModelField]
A sequence of ModelField objects representing parameters.
Expand All @@ -399,7 +401,7 @@ def _normalize_multi_query_string_with_param(
-------
A dictionary containing the processed multi_query_string_parameters.
"""
resolved_query_string: Dict[str, Any] = query_string
resolved_query_string: dict[str, Any] = query_string
for param in filter(is_scalar_field, params):
try:
# if the target parameter is a scalar, we keep the first value of the query string
Expand All @@ -416,7 +418,7 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any],

Parameters
----------
headers: Dict
headers: MutableMapping[str, Any]
A dictionary containing the initial header parameters.
params: Sequence[ModelField]
A sequence of ModelField objects representing parameters.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

import logging
from typing import Dict, Optional
from typing import TYPE_CHECKING

from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
from aws_lambda_powertools.utilities.validation import validate
from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError

if TYPE_CHECKING:
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.types import EventHandlerInstance

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -48,21 +52,21 @@ def lambda_handler(event, context):

def __init__(
self,
inbound_schema: Dict,
inbound_formats: Optional[Dict] = None,
outbound_schema: Optional[Dict] = None,
outbound_formats: Optional[Dict] = None,
inbound_schema: dict,
inbound_formats: dict | None = None,
outbound_schema: dict | None = None,
outbound_formats: dict | None = None,
):
"""See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters.

Parameters
----------
inbound_schema : Dict
inbound_schema : dict
JSON Schema to validate incoming event
inbound_formats : Optional[Dict], optional
inbound_formats : dict | None, optional
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
JSON Schema to validate outbound event, by default None
outbound_formats : Optional[Dict], optional
outbound_formats : dict | None, optional
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
""" # noqa: E501
super().__init__()
Expand Down
Loading