diff --git a/aws_lambda_powertools/utilities/parser/__init__.py b/aws_lambda_powertools/utilities/parser/__init__.py index 1bc67934b13..f0faefc3ea4 100644 --- a/aws_lambda_powertools/utilities/parser/__init__.py +++ b/aws_lambda_powertools/utilities/parser/__init__.py @@ -2,12 +2,13 @@ """ from . import envelopes from .envelopes import BaseEnvelope -from .parser import event_parser, parse +from .parser import chained_parse, event_parser, parse from .pydantic import BaseModel, Field, ValidationError, root_validator, validator __all__ = [ "event_parser", "parse", + "chained_parse", "envelopes", "BaseEnvelope", "BaseModel", diff --git a/aws_lambda_powertools/utilities/parser/compat.py b/aws_lambda_powertools/utilities/parser/compat.py index c73098421b1..2cc81feb57d 100644 --- a/aws_lambda_powertools/utilities/parser/compat.py +++ b/aws_lambda_powertools/utilities/parser/compat.py @@ -1,6 +1,14 @@ import functools +@functools.lru_cache(maxsize=None) +def pydantic_version() -> int: + from pydantic import __version__ + + version = __version__.split(".") + return int(version[0]) + + @functools.lru_cache(maxsize=None) def disable_pydantic_v2_warning(): """ diff --git a/aws_lambda_powertools/utilities/parser/exceptions.py b/aws_lambda_powertools/utilities/parser/exceptions.py index 0df217e8522..1bbf82474ff 100644 --- a/aws_lambda_powertools/utilities/parser/exceptions.py +++ b/aws_lambda_powertools/utilities/parser/exceptions.py @@ -4,3 +4,7 @@ class InvalidEnvelopeError(Exception): class InvalidModelTypeError(Exception): """Input data model does not implement BaseModel""" + + +class InvalidEnvelopeChaining(Exception): + """Input Envelopes combination does not support chaining""" diff --git a/aws_lambda_powertools/utilities/parser/models/raw_event.py b/aws_lambda_powertools/utilities/parser/models/raw_event.py new file mode 100644 index 00000000000..b401b1021fd --- /dev/null +++ b/aws_lambda_powertools/utilities/parser/models/raw_event.py @@ -0,0 +1,21 @@ +from typing import Any, Dict + +from aws_lambda_powertools.utilities.parser.compat import pydantic_version + +if pydantic_version() == 1: + from pydantic import BaseModel + + class RawEvent(BaseModel): + __root__: Dict[str, Any] + + def as_raw_dict(self) -> Dict[str, Any]: + return self.__root__ + +else: + from pydantic import RootModel # type: ignore[attr-defined] + + class RawEvent(RootModel): # type: ignore[no-redef] + root: Dict[str, Any] + + def as_raw_dict(self) -> Dict[str, Any]: + return self.root diff --git a/aws_lambda_powertools/utilities/parser/parser.py b/aws_lambda_powertools/utilities/parser/parser.py index 7e2d69e429c..a8637b092bb 100644 --- a/aws_lambda_powertools/utilities/parser/parser.py +++ b/aws_lambda_powertools/utilities/parser/parser.py @@ -1,13 +1,15 @@ +import itertools import logging -from typing import Any, Callable, Dict, Optional, Type, overload +from typing import Any, Callable, Dict, List, Optional, Type, Union, overload from aws_lambda_powertools.utilities.parser.compat import disable_pydantic_v2_warning +from aws_lambda_powertools.utilities.parser.models.raw_event import RawEvent from aws_lambda_powertools.utilities.parser.types import EventParserReturnType, Model from ...middleware_factory import lambda_handler_decorator from ..typing import LambdaContext from .envelopes.base import Envelope -from .exceptions import InvalidEnvelopeError, InvalidModelTypeError +from .exceptions import InvalidEnvelopeChaining, InvalidEnvelopeError, InvalidModelTypeError logger = logging.getLogger(__name__) @@ -18,7 +20,7 @@ def event_parser( event: Dict[str, Any], context: LambdaContext, model: Type[Model], - envelope: Optional[Type[Envelope]] = None, + envelope: Optional[Union[Type[Envelope], List[Type[Envelope]]]] = None, ) -> EventParserReturnType: """Lambda handler decorator to parse & validate events using Pydantic models @@ -80,7 +82,14 @@ def handler(event: Order, context: LambdaContext): InvalidEnvelopeError When envelope given does not implement BaseEnvelope """ - parsed_event = parse(event=event, model=model, envelope=envelope) if envelope else parse(event=event, model=model) + parsed_event: Union[Model, List[Any], Any] = None + if not envelope: + parsed_event = parse(event=event, model=model) + elif isinstance(envelope, List): + parsed_event = chained_parse(event=event, model=model, envelopes=envelope) + else: + parsed_event = parse(event=event, model=model, envelope=envelope) + logger.debug(f"Calling handler {handler.__name__}") return handler(parsed_event, context) @@ -165,3 +174,34 @@ def handler(event: Order, context: LambdaContext): return model.parse_obj(event) except AttributeError: raise InvalidModelTypeError(f"Input model must implement BaseModel, model={model}") + + +def _chained_parse(events: List[Dict[str, Any]], model: Type[Model], envelopes: List[Type[Envelope]]) -> List: + print(type(envelopes)) + if len(envelopes) == 1: + envelope = envelopes[0] + print(f"{events=}, {model=}, {envelope=}") + res = [parse(event=event, model=model, envelope=envelope) for event in events] + if isinstance(res[0], List): + return list(itertools.chain.from_iterable(res)) + return res + + envelope = envelopes[0] + dict_events = [] + for event in events: + parsed_event: Union[RawEvent, List[RawEvent]] = parse(event=event, model=RawEvent, envelope=envelope) + if isinstance(parsed_event, RawEvent): + dict_events.append(parsed_event.as_raw_dict()) + elif isinstance(parsed_event, List) and isinstance(parsed_event[0], RawEvent): + dict_events.extend(x.as_raw_dict() for x in parsed_event) + else: + raise InvalidEnvelopeChaining( + f"Return type expected is {RawEvent} or {List[RawEvent]}, " + f"received {type(parsed_event)} from envelope {envelope}", + ) + + return list(itertools.chain.from_iterable(_chained_parse(events=dict_events, model=model, envelopes=envelopes[1:]))) + + +def chained_parse(event: Dict[str, Any], model: Type[Model], envelopes: List[Type[Envelope]]) -> List: + return _chained_parse(events=[event], model=model, envelopes=envelopes) diff --git a/tests/unit/parser/test_sns.py b/tests/unit/parser/test_sns.py index 9b925d5fa76..2147709e9d7 100644 --- a/tests/unit/parser/test_sns.py +++ b/tests/unit/parser/test_sns.py @@ -1,8 +1,9 @@ import json +from typing import List import pytest -from aws_lambda_powertools.utilities.parser import ValidationError, envelopes, parse +from aws_lambda_powertools.utilities.parser import ValidationError, chained_parse, envelopes, parse from tests.functional.utils import load_event from tests.functional.validator.conftest import sns_event # noqa: F401 from tests.unit.parser.schemas import MyAdvancedSnsBusiness, MySnsBusiness @@ -96,6 +97,19 @@ def test_handle_sns_sqs_trigger_event_json_body(): # noqa: F811 assert parsed_event[0].username == "lessa" +def test_handle_sns_sqs_trigger_event_chained(): # noqa: F811 + raw_event = load_event("snsSqsEvent.json") + parsed_event: List[MySnsBusiness] = chained_parse( + event=raw_event, + model=MySnsBusiness, + envelopes=[envelopes.SnsSqsEnvelope], + ) + + assert len(parsed_event) == 1 + assert parsed_event[0].message == "hello world" + assert parsed_event[0].username == "lessa" + + def test_handle_sns_sqs_trigger_event_json_body_missing_unsubscribe_url(): # GIVEN an event is tampered with a missing UnsubscribeURL raw_event = load_event("snsSqsEvent.json")