Skip to content

fix(apigateway): allow list of HTTP methods in route method #838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from enum import Enum
from functools import partial
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import ServiceError
Expand Down Expand Up @@ -453,27 +453,30 @@ def __init__(
def route(
self,
rule: str,
method: str,
method: Union[str, Union[List[str], Tuple[str]]],
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
):
"""Route decorator includes parameter `method`"""

def register_resolver(func: Callable):
logger.debug(f"Adding route using rule {rule} and method {method.upper()}")
methods = (method,) if isinstance(method, str) else method
logger.debug(f"Adding route using rule {rule} and methods: {','.join((m.upper() for m in methods))}")
if cors is None:
cors_enabled = self._cors_enabled
else:
cors_enabled = cors
self._routes.append(Route(method, self._compile_regex(rule), func, cors_enabled, compress, cache_control))
route_key = method + rule
if route_key in self._route_keys:
warnings.warn(f"A route like this was already registered. method: '{method}' rule: '{rule}'")
self._route_keys.append(route_key)
if cors_enabled:
logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS")
self._cors_methods.add(method.upper())

for item in methods:
self._routes.append(Route(item, self._compile_regex(rule), func, cors_enabled, compress, cache_control))
route_key = item + rule
if route_key in self._route_keys:
warnings.warn(f"A route like this was already registered. method: '{item}' rule: '{rule}'")
self._route_keys.append(route_key)
if cors_enabled:
logger.debug(f"Registering method {item.upper()} to Allow Methods in CORS")
self._cors_methods.add(item.upper())
return func

return register_resolver
Expand Down Expand Up @@ -679,14 +682,14 @@ def __init__(self):
def route(
self,
rule: str,
method: Union[str, List[str]],
method: Union[str, Union[List[str], Tuple[str]]],
cors: Optional[bool] = None,
compress: bool = False,
cache_control: Optional[str] = None,
):
def register_route(func: Callable):
methods = method if isinstance(method, list) else [method]
for item in methods:
self._routes[(rule, item, cors, compress, cache_control)] = func
# Convert methods to tuple. It needs to be hashable as its part of the self._routes dict key
methods = (method,) if isinstance(method, str) else tuple(method)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

self._routes[(rule, methods, cors, compress, cache_control)] = func

return register_route
119 changes: 91 additions & 28 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,45 +42,27 @@ This is the sample infrastructure for API Gateway we are using for the examples
Timeout: 5
Runtime: python3.8
Tracing: Active
Environment:
Environment:
Variables:
LOG_LEVEL: INFO
POWERTOOLS_LOGGER_SAMPLE_RATE: 0.1
POWERTOOLS_LOGGER_LOG_EVENT: true
POWERTOOLS_METRICS_NAMESPACE: MyServerlessApplication
POWERTOOLS_SERVICE_NAME: hello
POWERTOOLS_SERVICE_NAME: my_api-service

Resources:
HelloWorldFunction:
ApiFunction:
Type: AWS::Serverless::Function
Properties:
Handler: app.lambda_handler
CodeUri: hello_world
Description: Hello World function
CodeUri: api_handler/
Description: API handler function
Events:
HelloUniverse:
Type: Api
Properties:
Path: /hello
Method: GET
HelloYou:
Type: Api
Properties:
Path: /hello/{name} # see Dynamic routes section
Method: GET
CustomMessage:
Type: Api
Properties:
Path: /{message}/{name} # see Dynamic routes section
Method: GET

Outputs:
HelloWorldApigwURL:
Description: "API Gateway endpoint URL for Prod environment for Hello World Function"
Value: !Sub "https://${ServerlessRestApi}.execute-api.${AWS::Region}.amazonaws.com/Prod/hello"
HelloWorldFunction:
Description: "Hello World Lambda Function ARN"
Value: !GetAtt HelloWorldFunction.Arn
ApiEvent:
Type: Api
Properties:
Path: /{proxy+} # Send requests on any path to the lambda function
Method: ANY # Send requests using any http method to the lambda function
```

### API Gateway decorator
Expand Down Expand Up @@ -360,6 +342,87 @@ You can also combine nested paths with greedy regex to catch in between routes.
...
}
```
### HTTP Methods
You can use named decorators to specify the HTTP method that should be handled in your functions. As well as the
`get` method already shown above, you can use `post`, `put`, `patch`, `delete`, and `patch`.

=== "app.py"

```python hl_lines="9-10"
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
logger = Logger()
app = ApiGatewayResolver()

# Only POST HTTP requests to the path /hello will route to this function
@app.post("/hello")
@tracer.capture_method
def get_hello_you():
name = app.current_event.json_body.get("name")
return {"message": f"hello {name}"}

# You can continue to use other utilities just as before
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST)
@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```

=== "sample_request.json"

```json
{
"resource": "/hello/{name}",
"path": "/hello/lessa",
"httpMethod": "GET",
...
}
```

If you need to accept multiple HTTP methods in a single function, you can use the `route` method and pass a list of
HTTP methods.

=== "app.py"

```python hl_lines="9-10"
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver

tracer = Tracer()
logger = Logger()
app = ApiGatewayResolver()

# PUT and POST HTTP requests to the path /hello will route to this function
@app.route("/hello", method=["PUT", "POST"])
@tracer.capture_method
def get_hello_you():
name = app.current_event.json_body.get("name")
return {"message": f"hello {name}"}

# You can continue to use other utilities just as before
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST)
@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
```

=== "sample_request.json"

```json
{
"resource": "/hello/{name}",
"path": "/hello/lessa",
"httpMethod": "GET",
...
}
```

!!! note "It is usually better to have separate functions for each HTTP method, as the functionality tends to differ
depending on which method is used."

### Accessing request details

Expand Down
36 changes: 36 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,3 +1021,39 @@ def get_func_another_duplicate():
# THEN only execute the first registered route
# AND print warnings
assert result["statusCode"] == 200


def test_route_multiple_methods():
# GIVEN a function with http methods passed as a list
app = ApiGatewayResolver()
req = "foo"
get_event = deepcopy(LOAD_GW_EVENT)
get_event["resource"] = "/accounts/{account_id}"
get_event["path"] = f"/accounts/{req}"

post_event = deepcopy(get_event)
post_event["httpMethod"] = "POST"

put_event = deepcopy(get_event)
put_event["httpMethod"] = "PUT"

lambda_context = {}

@app.route(rule="/accounts/<account_id>", method=["GET", "POST"])
def foo(account_id):
assert app.lambda_context == lambda_context
assert account_id == f"{req}"
return {}

# WHEN calling the event handler with the supplied methods
get_result = app(get_event, lambda_context)
post_result = app(post_event, lambda_context)
put_result = app(put_event, lambda_context)

# THEN events are processed correctly
assert get_result["statusCode"] == 200
assert get_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
assert post_result["statusCode"] == 200
assert post_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
assert put_result["statusCode"] == 404
assert put_result["headers"]["Content-Type"] == content_types.APPLICATION_JSON