diff --git a/Makefile b/Makefile
index 2cff4996889..f6ad229e1eb 100644
--- a/Makefile
+++ b/Makefile
@@ -84,7 +84,7 @@ complexity-baseline:
 	$(info Maintenability index)
 	poetry run radon mi aws_lambda_powertools
 	$(info Cyclomatic complexity index)
-	poetry run xenon --max-absolute C --max-modules A --max-average A aws_lambda_powertools
+	poetry run xenon --max-absolute C --max-modules A --max-average A aws_lambda_powertools --exclude aws_lambda_powertools/shared/json_encoder.py
 
 #
 # Use `poetry version <major>/<minor></patch>` for version bump
diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py
index 0007540453d..31536457344 100644
--- a/aws_lambda_powertools/event_handler/api_gateway.py
+++ b/aws_lambda_powertools/event_handler/api_gateway.py
@@ -709,7 +709,7 @@ class ResponseBuilder(Generic[ResponseEventT]):
     def __init__(
         self,
         response: Response,
-        serializer: Callable[[Any], str] = json.dumps,
+        serializer: Callable[[Any], str] = partial(json.dumps, separators=(",", ":"), cls=Encoder),
         route: Optional[Route] = None,
     ):
         self.response = response
diff --git a/aws_lambda_powertools/shared/functions.py b/aws_lambda_powertools/shared/functions.py
index 82ea7dad8d8..fb36b98dc34 100644
--- a/aws_lambda_powertools/shared/functions.py
+++ b/aws_lambda_powertools/shared/functions.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import base64
-import dataclasses
 import itertools
 import logging
 import os
@@ -168,8 +167,86 @@ def extract_event_from_common_models(data: Any) -> Dict | Any:
         return data.raw_event
 
     # Is it a Pydantic Model?
-    if callable(getattr(data, "dict", None)):
-        return data.dict()
+    if is_pydantic(data):
+        return pydantic_to_dict(data)
 
-    # Is it a Dataclass? If not return as is
-    return dataclasses.asdict(data) if dataclasses.is_dataclass(data) else data
+    # Is it a Dataclass?
+    if is_dataclass(data):
+        return dataclass_to_dict(data)
+
+    # Return as is
+    return data
+
+
+def is_pydantic(data) -> bool:
+    """Whether data is a Pydantic model, detected via a field common to Pydantic v1 and v2
+
+    Parameters
+    ----------
+    data: BaseModel
+        Pydantic model
+
+    Returns
+    -------
+    bool
+        Whether it's a Pydantic model
+    """
+    return getattr(data, "json", False)
+
+
+def is_dataclass(data) -> bool:
+    """Whether data is a dataclass
+
+    Parameters
+    ----------
+    data: dataclass
+        Dataclass instance
+
+    Returns
+    -------
+    bool
+        Whether it's a dataclass
+    """
+    return getattr(data, "__dataclass_fields__", False)
+
+
+def pydantic_to_dict(data) -> dict:
+    """Dump a Pydantic model (v1 or v2) as a dict.
+
+    Note we use a lazy import since Pydantic is an optional dependency.
+
+    Parameters
+    ----------
+    data: BaseModel
+        Pydantic model
+
+    Returns
+    -------
+
+    dict:
+        Pydantic model serialized to dict
+    """
+    from aws_lambda_powertools.event_handler.openapi.compat import _model_dump
+
+    return _model_dump(data)
+
+
+def dataclass_to_dict(data) -> dict:
+    """Dump a standard dataclass as a dict.
+
+    Note we use a lazy import to prevent bloating other code parts.
+
+    Parameters
+    ----------
+    data: dataclass
+        Dataclass
+
+    Returns
+    -------
+
+    dict:
+        Dataclass serialized to dict
+    """
+    import dataclasses
+
+    return dataclasses.asdict(data)
diff --git a/aws_lambda_powertools/shared/json_encoder.py b/aws_lambda_powertools/shared/json_encoder.py
index 32a094abd85..867745b2866 100644
--- a/aws_lambda_powertools/shared/json_encoder.py
+++ b/aws_lambda_powertools/shared/json_encoder.py
@@ -2,10 +2,13 @@
 import json
 import math
 
+from aws_lambda_powertools.shared.functions import dataclass_to_dict, is_dataclass, is_pydantic, pydantic_to_dict
+
 
 class Encoder(json.JSONEncoder):
-    """
-    Custom JSON encoder to allow for serialization of Decimals, similar to the serializer used by Lambda internally.
+    """Custom JSON encoder to allow for serialization of Decimals, Pydantic models and dataclasses.
+
+    It's similar to the serializer used by Lambda internally.
     """
 
     def default(self, obj):
@@ -13,4 +16,11 @@ def default(self, obj):
             if obj.is_nan():
                 return math.nan
             return str(obj)
+
+        if is_pydantic(obj):
+            return pydantic_to_dict(obj)
+
+        if is_dataclass(obj):
+            return dataclass_to_dict(obj)
+
         return super().default(obj)
diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py
index d4c88b541aa..9c98faff062 100644
--- a/tests/functional/event_handler/test_api_gateway.py
+++ b/tests/functional/event_handler/test_api_gateway.py
@@ -10,6 +10,7 @@
 from typing import Dict
 
 import pytest
+from pydantic import BaseModel
 
 from aws_lambda_powertools.event_handler import content_types
 from aws_lambda_powertools.event_handler.api_gateway import (
@@ -1465,7 +1466,6 @@ def test_exception_handler_with_data_validation():
 
     @app.exception_handler(RequestValidationError)
    def handle_validation_error(ex: RequestValidationError):
-        print(f"request path is '{app.current_event.path}'")
         return Response(
             status_code=422,
             content_type=content_types.TEXT_PLAIN,
@@ -1486,6 +1486,34 @@ def get_lambda(param: int):
     assert result["body"] == "Invalid data. Number of errors: 1"
 
 
+def test_exception_handler_with_data_validation_pydantic_response():
+    # GIVEN a resolver with an exception handler defined for RequestValidationError
+    app = ApiGatewayResolver(enable_validation=True)
+
+    class Err(BaseModel):
+        msg: str
+
+    @app.exception_handler(RequestValidationError)
+    def handle_validation_error(ex: RequestValidationError):
+        return Response(
+            status_code=422,
+            content_type=content_types.APPLICATION_JSON,
+            body=Err(msg=f"Invalid data. Number of errors: {len(ex.errors())}"),
+        )
+
+    @app.get("/my/path")
+    def get_lambda(param: int):
+        ...
+
+    # WHEN calling the event handler
+    # AND a RequestValidationError is raised
+    result = app(LOAD_GW_EVENT, {})
+
+    # THEN the exception handler's Pydantic response should be serialized correctly
+    assert result["statusCode"] == 422
+    assert result["body"] == '{"msg":"Invalid data. Number of errors: 1"}'
+
+
 def test_data_validation_error():
     # GIVEN a resolver without an exception handler
     app = ApiGatewayResolver(enable_validation=True)
diff --git a/tests/unit/test_json_encoder.py b/tests/unit/test_json_encoder.py
index af8de4257a8..0dad7634df5 100644
--- a/tests/unit/test_json_encoder.py
+++ b/tests/unit/test_json_encoder.py
@@ -1,7 +1,9 @@
 import decimal
 import json
+from dataclasses import dataclass
 
 import pytest
+from pydantic import BaseModel
 
 from aws_lambda_powertools.shared.json_encoder import Encoder
 
@@ -22,3 +24,35 @@ class CustomClass:
 
     with pytest.raises(TypeError):
         json.dumps({"val": CustomClass()}, cls=Encoder)
+
+
+def test_json_encode_pydantic():
+    # GIVEN a Pydantic model
+    class Model(BaseModel):
+        data: dict
+
+    data = {"msg": "hello"}
+    model = Model(data=data)
+
+    # WHEN json.dumps uses our custom Encoder
+    result = json.dumps(model, cls=Encoder)
+
+    # THEN we should serialize successfully and not raise a TypeError
+    assert result == json.dumps({"data": data}, cls=Encoder)
+
+
+def test_json_encode_dataclasses():
+    # GIVEN a standard dataclass
+
+    @dataclass
+    class Model:
+        data: dict
+
+    data = {"msg": "hello"}
+    model = Model(data=data)
+
+    # WHEN json.dumps uses our custom Encoder
+    result = json.dumps(model, cls=Encoder)
+
+    # THEN we should serialize successfully and not raise a TypeError
+    assert result == json.dumps({"data": data}, cls=Encoder)
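
Reviewer note: a minimal sketch of what this change enables, mirroring the new unit tests rather than adding anything beyond the diff. It assumes Pydantic is installed (an optional dependency), and the User/Order names are purely illustrative.

import json
from dataclasses import dataclass

from pydantic import BaseModel

from aws_lambda_powertools.shared.json_encoder import Encoder


class User(BaseModel):  # illustrative Pydantic model, not part of the diff
    name: str


@dataclass
class Order:  # illustrative dataclass, not part of the diff
    total: int


# Both objects now fall through Encoder.default() into pydantic_to_dict()/dataclass_to_dict()
# instead of raising TypeError, so event handler responses can return them directly.
print(json.dumps({"user": User(name="Ana"), "order": Order(total=3)}, cls=Encoder))
# {"user": {"name": "Ana"}, "order": {"total": 3}}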