diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py new file mode 100644 index 00000000000..0475982e377 --- /dev/null +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -0,0 +1,7 @@ +""" +Event handler decorators for common Lambda events +""" + +from .appsync import AppSyncResolver + +__all__ = ["AppSyncResolver"] diff --git a/aws_lambda_powertools/event_handler/appsync.py b/aws_lambda_powertools/event_handler/appsync.py new file mode 100644 index 00000000000..021afaa6654 --- /dev/null +++ b/aws_lambda_powertools/event_handler/appsync.py @@ -0,0 +1,113 @@ +import logging +from typing import Any, Callable + +from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent +from aws_lambda_powertools.utilities.typing import LambdaContext + +logger = logging.getLogger(__name__) + + +class AppSyncResolver: + """ + AppSync resolver decorator + + Example + ------- + + **Sample usage** + + from aws_lambda_powertools.event_handler import AppSyncResolver + + app = AppSyncResolver() + + @app.resolver(type_name="Query", field_name="listLocations") + def list_locations(page: int = 0, size: int = 10) -> list: + # Your logic to fetch locations with arguments passed in + return [{"id": 100, "name": "Smooth Grooves"}] + + @app.resolver(type_name="Merchant", field_name="extraInfo") + def get_extra_info() -> dict: + # Can use "app.current_event.source" to filter within the parent context + account_type = app.current_event.source["accountType"] + method = "BTC" if account_type == "NEW" else "USD" + return {"preferredPaymentMethod": method} + + @app.resolver(field_name="commonField") + def common_field() -> str: + # Would match all fieldNames matching 'commonField' + return str(uuid.uuid4()) + """ + + current_event: AppSyncResolverEvent + lambda_context: LambdaContext + + def __init__(self): + self._resolvers: dict = {} + + def resolver(self, type_name: str = "*", field_name: str = None): + """Registers the resolver for field_name + + Parameters + ---------- + type_name : str + Type name + field_name : str + Field name + """ + + def register_resolver(func): + logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`") + self._resolvers[f"{type_name}.{field_name}"] = {"func": func} + return func + + return register_resolver + + def resolve(self, event: dict, context: LambdaContext) -> Any: + """Resolve field_name + + Parameters + ---------- + event : dict + Lambda event + context : LambdaContext + Lambda context + + Returns + ------- + Any + Returns the result of the resolver + + Raises + ------- + ValueError + If we could not find a field resolver + """ + self.current_event = AppSyncResolverEvent(event) + self.lambda_context = context + resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name) + return resolver(**self.current_event.arguments) + + def _get_resolver(self, type_name: str, field_name: str) -> Callable: + """Get resolver for field_name + + Parameters + ---------- + type_name : str + Type name + field_name : str + Field name + + Returns + ------- + Callable + callable function and configuration + """ + full_name = f"{type_name}.{field_name}" + resolver = self._resolvers.get(full_name, self._resolvers.get(f"*.{field_name}")) + if not resolver: + raise ValueError(f"No resolver found for '{full_name}'") + return resolver["func"] + + def __call__(self, event, context) -> Any: + """Implicit lambda handler which internally calls `resolve`""" + return self.resolve(event, context) diff --git a/aws_lambda_powertools/utilities/data_classes/__init__.py b/aws_lambda_powertools/utilities/data_classes/__init__.py index 28179bfd291..58464ebcf99 100644 --- a/aws_lambda_powertools/utilities/data_classes/__init__.py +++ b/aws_lambda_powertools/utilities/data_classes/__init__.py @@ -1,7 +1,10 @@ -from aws_lambda_powertools.utilities.data_classes.appsync_resolver_event import AppSyncResolverEvent +""" +Event Source Data Classes utility provides classes self-describing Lambda event sources. +""" from .alb_event import ALBEvent from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2 +from .appsync_resolver_event import AppSyncResolverEvent from .cloud_watch_logs_event import CloudWatchLogsEvent from .connect_contact_flow_event import ConnectContactFlowEvent from .dynamo_db_stream_event import DynamoDBStreamEvent diff --git a/aws_lambda_powertools/utilities/data_classes/appsync/resolver_utils.py b/aws_lambda_powertools/utilities/data_classes/appsync/resolver_utils.py deleted file mode 100644 index 9329b8effe8..00000000000 --- a/aws_lambda_powertools/utilities/data_classes/appsync/resolver_utils.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Any, Dict - -from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent -from aws_lambda_powertools.utilities.typing import LambdaContext - - -class AppSyncResolver: - def __init__(self): - self._resolvers: dict = {} - - def resolver( - self, - type_name: str = "*", - field_name: str = None, - include_event: bool = False, - include_context: bool = False, - **kwargs, - ): - def register_resolver(func): - kwargs["include_event"] = include_event - kwargs["include_context"] = include_context - self._resolvers[f"{type_name}.{field_name}"] = { - "func": func, - "config": kwargs, - } - return func - - return register_resolver - - def resolve(self, _event: dict, context: LambdaContext) -> Any: - event = AppSyncResolverEvent(_event) - resolver, config = self._resolver(event.type_name, event.field_name) - kwargs = self._kwargs(event, context, config) - return resolver(**kwargs) - - def _resolver(self, type_name: str, field_name: str) -> tuple: - full_name = f"{type_name}.{field_name}" - resolver = self._resolvers.get(full_name, self._resolvers.get(f"*.{field_name}")) - if not resolver: - raise ValueError(f"No resolver found for '{full_name}'") - return resolver["func"], resolver["config"] - - @staticmethod - def _kwargs(event: AppSyncResolverEvent, context: LambdaContext, config: dict) -> Dict[str, Any]: - kwargs = {**event.arguments} - if config.get("include_event", False): - kwargs["event"] = event - if config.get("include_context", False): - kwargs["context"] = context - return kwargs diff --git a/docs/core/event_handler.md b/docs/core/event_handler.md new file mode 100644 index 00000000000..8de4e83c4ad --- /dev/null +++ b/docs/core/event_handler.md @@ -0,0 +1,223 @@ +--- +title: Event Handler +description: Utility +--- + +Event handler decorators for common Lambda events + + +### AppSync Resolver Decorator + +> New in 1.14.0 + +AppSync resolver decorator is a concise way to create lambda functions to handle AppSync resolvers for multiple +`typeName` and `fieldName` declarations. This decorator builds on top of the +[AppSync Resolver ](/utilities/data_classes#appsync-resolver) data class and therefore works with [Amplify GraphQL Transform Library](https://docs.amplify.aws/cli/graphql-transformer/function){target="_blank"} (`@function`), +and [AppSync Direct Lambda Resolvers](https://aws.amazon.com/blogs/mobile/appsync-direct-lambda/){target="_blank"} + +#### Key Features + +* Works with any of the existing Powertools utilities by allow you to create your own `lambda_handler` function +* Supports an implicit handler where in `app = AppSyncResolver()` can be invoked directly as `app(event, context)` +* `resolver` decorator has flexible or strict matching against `fieldName` +* Arguments are automatically passed into your function +* AppSyncResolver includes `current_event` and `lambda_cotext` fields can be used to pass in the original `AppSyncResolver` or `LambdaContext` + objects + +#### Amplify GraphQL Example + +Create a new GraphQL api via `amplify add api` and add the following to the new `schema.graphql` + +=== "schema.graphql" + + ```typescript hl_lines="7-10 17-18 22-25" + @model + type Merchant + { + id: String! + name: String! + description: String + # Resolves to `get_extra_info` + extraInfo: ExtraInfo @function(name: "merchantInfo-${env}") + # Resolves to `common_field` + commonField: String @function(name: "merchantInfo-${env}") + } + + type Location { + id: ID! + name: String! + address: Address + # Resolves to `common_field` + commonField: String @function(name: "merchantInfo-${env}") + } + + type Query { + # List of locations resolves to `list_locations` + listLocations(page: Int, size: Int): [Location] @function(name: "merchantInfo-${env}") + # List of locations resolves to `list_locations` + findMerchant(search: str): [Merchant] @function(name: "searchMerchant-${env}") + } + ``` + +Create two new simple Python functions via `amplify add function` and run `pipenv install aws-lambda-powertools` to +add Powertools as a dependency. Add the following example lambda implementation + +=== "merchantInfo/src/app.py" + + ```python hl_lines="1-2 6 8-9 13-14 18-19 24 26" + from aws_lambda_powertools.event_handler import AppSyncResolver + from aws_lambda_powertools.logging import Logger, Tracer, correlation_paths + + tracer = Tracer() + logger = Logger() + app = AppSyncResolver() + + @app.resolver(type_name="Query", field_name="listLocations") + def list_locations(page: int = 0, size: int = 10): + # Your logic to fetch locations + ... + + @app.resolver(type_name="Merchant", field_name="extraInfo") + def get_extra_info(): + # Can use `app.current_event.source["id"]` to filter within the Merchant context + ... + + @app.resolver(field_name="commonField") + def common_field(): + # Would match all fieldNames matching 'commonField' + ... + + @tracer.capture_lambda_handler + @logger.inject_lambda_context(correlation_id_path=correlation_paths.APPSYNC_RESOLVER) + def lambda_handler(event, context): + app.resolve(event, context) + ``` +=== "searchMerchant/src/app.py" + + ```python hl_lines="1 3 5-6" + from aws_lambda_powertools.event_handler import AppSyncResolver + + app = AppSyncResolver() + + @app.resolver(type_name="Query", field_name="findMerchant") + def find_merchant(search: str): + # Your special search function + ... + ``` + +Example AppSync resolver events + +=== "Query.listLocations event" + + ```json hl_lines="2-7" + { + "typeName": "Query", + "fieldName": "listLocations", + "arguments": { + "page": 2, + "size": 1 + }, + "identity": { + "claims": { + "iat": 1615366261 + ... + }, + "username": "mike", + ... + }, + "request": { + "headers": { + "x-amzn-trace-id": "Root=1-60488877-0b0c4e6727ab2a1c545babd0", + "x-forwarded-for": "127.0.0.1" + ... + } + }, + ... + } + ``` + +=== "Merchant.extraInfo event" + + ```json hl_lines="2-5 14-17" + { + "typeName": "Merchant", + "fieldName": "extraInfo", + "arguments": { + }, + "identity": { + "claims": { + "iat": 1615366261 + ... + }, + "username": "mike", + ... + }, + "source": { + "id": "12345", + "name: "Pizza Parlor" + }, + "request": { + "headers": { + "x-amzn-trace-id": "Root=1-60488877-0b0c4e6727ab2a1c545babd0", + "x-forwarded-for": "127.0.0.1" + ... + } + }, + ... + } + ``` + +=== "*.commonField event" + + ```json hl_lines="2 3" + { + "typeName": "Merchant", + "fieldName": "commonField", + "arguments": { + }, + "identity": { + "claims": { + "iat": 1615366261 + ... + }, + "username": "mike", + ... + }, + "request": { + "headers": { + "x-amzn-trace-id": "Root=1-60488877-0b0c4e6727ab2a1c545babd0", + "x-forwarded-for": "127.0.0.1" + ... + } + }, + ... + } + ``` + +=== "Query.findMerchant event" + + ```json hl_lines="2-6" + { + "typeName": "Query", + "fieldName": "findMerchant", + "arguments": { + "search": "Brewers Coffee" + }, + "identity": { + "claims": { + "iat": 1615366261 + ... + }, + "username": "mike", + ... + }, + "request": { + "headers": { + "x-amzn-trace-id": "Root=1-60488877-0b0c4e6727ab2a1c545babd0", + "x-forwarded-for": "127.0.0.1" + ... + } + }, + ... + } + ``` diff --git a/docs/index.md b/docs/index.md index 1f347b017e1..2e8a46cc3b8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -144,6 +144,7 @@ aws serverlessrepo list-application-versions \ | [Tracing](./core/tracer) | Decorators and utilities to trace Lambda function handlers, and both synchronous and asynchronous functions | [Logger](./core/logger) | Structured logging made easier, and decorator to enrich structured logging with key Lambda context details | [Metrics](./core/metrics) | Custom Metrics created asynchronously via CloudWatch Embedded Metric Format (EMF) +| [Event handler](./core/event_handler) | Event handler decorators for common Lambda events | [Middleware factory](./utilities/middleware_factory) | Decorator factory to create your own middleware to run logic before, and after each Lambda invocation | [Parameters](./utilities/parameters) | Retrieve parameter values from AWS Systems Manager Parameter Store, AWS Secrets Manager, or Amazon DynamoDB, and cache them for a specific amount of time | [Batch processing](./utilities/batch) | Handle partial failures for AWS SQS batch processing diff --git a/docs/utilities/data_classes.md b/docs/utilities/data_classes.md index fa85c719243..dc56ed8ec41 100644 --- a/docs/utilities/data_classes.md +++ b/docs/utilities/data_classes.md @@ -3,8 +3,7 @@ title: Event Source Data Classes description: Utility --- -Event Source Data Classes utility provides classes self-describing Lambda event sources, including API decorators when -applicable. +Event Source Data Classes utility provides classes self-describing Lambda event sources. ## Key Features diff --git a/mkdocs.yml b/mkdocs.yml index d8d37830369..0aa95693354 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -9,6 +9,7 @@ nav: - core/tracer.md - core/logger.md - core/metrics.md + - core/event_handler.md - Utilities: - utilities/middleware_factory.md - utilities/parameters.md diff --git a/tests/functional/event_handler/__init__.py b/tests/functional/event_handler/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/functional/appsync/test_appsync_resolver_utils.py b/tests/functional/event_handler/test_appsync.py similarity index 52% rename from tests/functional/appsync/test_appsync_resolver_utils.py rename to tests/functional/event_handler/test_appsync.py index a1388a1fb5c..c72331c32f1 100644 --- a/tests/functional/appsync/test_appsync_resolver_utils.py +++ b/tests/functional/event_handler/test_appsync.py @@ -1,28 +1,18 @@ import asyncio -import datetime import json -import os import sys +from pathlib import Path import pytest +from aws_lambda_powertools.event_handler import AppSyncResolver from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent -from aws_lambda_powertools.utilities.data_classes.appsync.resolver_utils import AppSyncResolver -from aws_lambda_powertools.utilities.data_classes.appsync.scalar_types_utils import ( - _formatted_time, - aws_date, - aws_datetime, - aws_time, - aws_timestamp, - make_id, -) from aws_lambda_powertools.utilities.typing import LambdaContext def load_event(file_name: str) -> dict: - full_file_name = os.path.dirname(os.path.realpath(__file__)) + "/../../events/" + file_name - with open(full_file_name) as fp: - return json.load(fp) + path = Path(str(Path(__file__).parent.parent.parent) + "/events/" + file_name) + return json.loads(path.read_text()) def test_direct_resolver(): @@ -31,15 +21,14 @@ def test_direct_resolver(): app = AppSyncResolver() - @app.resolver(field_name="createSomething", include_context=True) - def create_something(context, id: str): # noqa AA03 VNE003 - assert context == {} + @app.resolver(field_name="createSomething") + def create_something(id: str): # noqa AA03 VNE003 + assert app.lambda_context == {} return id - def handler(event, context): - return app.resolve(event, context) + # Call the implicit handler + result = app(mock_event, {}) - result = handler(mock_event, {}) assert result == "my identifier" @@ -49,14 +38,16 @@ def test_amplify_resolver(): app = AppSyncResolver() - @app.resolver(type_name="Merchant", field_name="locations", include_event=True) - def get_location(event: AppSyncResolverEvent, page: int, size: int, name: str): - assert event is not None + @app.resolver(type_name="Merchant", field_name="locations") + def get_location(page: int, size: int, name: str): + assert app.current_event is not None + assert isinstance(app.current_event, AppSyncResolverEvent) assert page == 2 assert size == 1 return name def handler(event, context): + # Call the explicit resolve function return app.resolve(event, context) result = handler(mock_event, {}) @@ -80,42 +71,6 @@ def no_params(): assert result == "no_params has no params" -def test_resolver_include_event(): - # GIVEN - app = AppSyncResolver() - - mock_event = {"typeName": "Query", "fieldName": "field", "arguments": {}} - - @app.resolver(field_name="field", include_event=True) - def get_value(event: AppSyncResolverEvent): - return event - - # WHEN - result = app.resolve(mock_event, LambdaContext()) - - # THEN - assert result._data == mock_event - assert isinstance(result, AppSyncResolverEvent) - - -def test_resolver_include_context(): - # GIVEN - app = AppSyncResolver() - - mock_event = {"typeName": "Query", "fieldName": "field", "arguments": {}} - - @app.resolver(field_name="field", include_context=True) - def get_value(context: LambdaContext): - return context - - # WHEN - mock_context = LambdaContext() - result = app.resolve(mock_event, mock_context) - - # THEN - assert result == mock_context - - def test_resolver_value_error(): # GIVEN no defined field resolver app = AppSyncResolver() @@ -189,46 +144,3 @@ async def get_async(): # THEN assert asyncio.run(result) == "value" - - -def test_make_id(): - uuid: str = make_id() - assert isinstance(uuid, str) - assert len(uuid) == 36 - - -def test_aws_date_utc(): - date_str = aws_date() - assert isinstance(date_str, str) - assert datetime.datetime.strptime(date_str, "%Y-%m-%dZ") - - -def test_aws_time_utc(): - time_str = aws_time() - assert isinstance(time_str, str) - assert datetime.datetime.strptime(time_str, "%H:%M:%SZ") - - -def test_aws_datetime_utc(): - datetime_str = aws_datetime() - assert isinstance(datetime_str, str) - assert datetime.datetime.strptime(datetime_str, "%Y-%m-%dT%H:%M:%SZ") - - -def test_aws_timestamp(): - timestamp = aws_timestamp() - assert isinstance(timestamp, int) - - -def test_format_time_positive(): - now = datetime.datetime(2022, 1, 22) - datetime_str = _formatted_time(now, "%Y-%m-%d", 8) - assert isinstance(datetime_str, str) - assert datetime_str == "2022-01-22+08:00:00" - - -def test_format_time_negative(): - now = datetime.datetime(2022, 1, 22, 14, 22, 33) - datetime_str = _formatted_time(now, "%H:%M:%S", -12) - assert isinstance(datetime_str, str) - assert datetime_str == "02:22:33-12:00:00" diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_data_classes.py similarity index 97% rename from tests/functional/test_lambda_trigger_events.py rename to tests/functional/test_data_classes.py index 73fc6057265..0221acc6853 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_data_classes.py @@ -1,4 +1,5 @@ import base64 +import datetime import json import os from secrets import compare_digest @@ -17,6 +18,14 @@ SNSEvent, SQSEvent, ) +from aws_lambda_powertools.utilities.data_classes.appsync.scalar_types_utils import ( + _formatted_time, + aws_date, + aws_datetime, + aws_time, + aws_timestamp, + make_id, +) from aws_lambda_powertools.utilities.data_classes.appsync_resolver_event import ( AppSyncIdentityCognito, AppSyncIdentityIAM, @@ -1060,3 +1069,46 @@ def test_s3_object_event_temp_credentials(): assert session_attributes is not None assert session_attributes.mfa_authenticated == session_context["attributes"]["mfaAuthenticated"] assert session_attributes.creation_date == session_context["attributes"]["creationDate"] + + +def test_make_id(): + uuid: str = make_id() + assert isinstance(uuid, str) + assert len(uuid) == 36 + + +def test_aws_date_utc(): + date_str = aws_date() + assert isinstance(date_str, str) + assert datetime.datetime.strptime(date_str, "%Y-%m-%dZ") + + +def test_aws_time_utc(): + time_str = aws_time() + assert isinstance(time_str, str) + assert datetime.datetime.strptime(time_str, "%H:%M:%SZ") + + +def test_aws_datetime_utc(): + datetime_str = aws_datetime() + assert isinstance(datetime_str, str) + assert datetime.datetime.strptime(datetime_str, "%Y-%m-%dT%H:%M:%SZ") + + +def test_aws_timestamp(): + timestamp = aws_timestamp() + assert isinstance(timestamp, int) + + +def test_format_time_positive(): + now = datetime.datetime(2022, 1, 22) + datetime_str = _formatted_time(now, "%Y-%m-%d", 8) + assert isinstance(datetime_str, str) + assert datetime_str == "2022-01-22+08:00:00" + + +def test_format_time_negative(): + now = datetime.datetime(2022, 1, 22, 14, 22, 33) + datetime_str = _formatted_time(now, "%H:%M:%S", -12) + assert isinstance(datetime_str, str) + assert datetime_str == "02:22:33-12:00:00"