Skip to content

feat(event_source): add class APIGatewayAuthorizerResponseWebSocket #6058

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
from typing import Any, overload

from typing_extensions import deprecated
from typing_extensions import deprecated, override

from aws_lambda_powertools.utilities.data_classes.common import (
BaseRequestContext,
Expand All @@ -28,9 +28,10 @@ def __init__(
aws_account_id: str,
api_id: str,
stage: str,
http_method: str,
http_method: str | None,
resource: str,
partition: str = "aws",
is_websocket_authorizer: bool = False,
):
self.partition = partition
self.region = region
Expand All @@ -40,39 +41,54 @@ def __init__(
self.http_method = http_method
# Remove matching "/" from `resource`.
self.resource = resource.lstrip("/")
self.is_websocket_authorizer = is_websocket_authorizer

@property
def arn(self) -> str:
"""Build an arn from its parts
eg: arn:aws:execute-api:us-east-1:123456789012:abcdef123/test/GET/request"""
return (
f"arn:{self.partition}:execute-api:{self.region}:{self.aws_account_id}:{self.api_id}/{self.stage}/"
f"{self.http_method}/{self.resource}"
)
base_arn = f"arn:{self.partition}:execute-api:{self.region}:{self.aws_account_id}:{self.api_id}/{self.stage}"

if not self.is_websocket_authorizer:
return f"{base_arn}/{self.http_method}/{self.resource}"
else:
return f"{base_arn}/{self.resource}"


def parse_api_gateway_arn(arn: str) -> APIGatewayRouteArn:
def parse_api_gateway_arn(arn: str, is_websocket_authorizer: bool = False) -> APIGatewayRouteArn:
"""Parses a gateway route arn as a APIGatewayRouteArn class

Parameters
----------
arn : str
ARN string for a methodArn or a routeArn
is_websocket_authorizer: bool
If it's a API Gateway Websocket

Returns
-------
APIGatewayRouteArn
"""
arn_parts = arn.split(":")
api_gateway_arn_parts = arn_parts[5].split("/")

if not is_websocket_authorizer:
http_method = api_gateway_arn_parts[2]
resource = "/".join(api_gateway_arn_parts[3:]) if len(api_gateway_arn_parts) >= 4 else ""
else:
http_method = None
resource = "/".join(api_gateway_arn_parts[2:])

return APIGatewayRouteArn(
partition=arn_parts[1],
region=arn_parts[3],
aws_account_id=arn_parts[4],
api_id=api_gateway_arn_parts[0],
stage=api_gateway_arn_parts[1],
http_method=api_gateway_arn_parts[2],
http_method=http_method,
# conditional allow us to handle /path/{proxy+} resources, as their length changes.
resource="/".join(api_gateway_arn_parts[3:]) if len(api_gateway_arn_parts) >= 4 else "",
resource=resource,
is_websocket_authorizer=is_websocket_authorizer,
)


Expand Down Expand Up @@ -512,13 +528,14 @@ def _add_route(self, effect: str, http_method: str, resource: str, conditions: l
raise ValueError(f"Invalid resource path: {resource}. Path should match {self.path_regex}")

resource_arn = APIGatewayRouteArn(
self.region,
self.aws_account_id,
self.api_id,
self.stage,
http_method,
resource,
self.partition,
region=self.region,
aws_account_id=self.aws_account_id,
api_id=self.api_id,
stage=self.stage,
http_method=http_method,
resource=resource,
partition=self.partition,
is_websocket_authorizer=False,
).arn

route = {"resourceArn": resource_arn, "conditions": conditions}
Expand Down Expand Up @@ -617,3 +634,127 @@ def asdict(self) -> dict[str, Any]:
response["context"] = self.context

return response


class APIGatewayAuthorizerResponseWebSocket(APIGatewayAuthorizerResponse):
"""The IAM Policy Response required for API Gateway WebSocket APIs

Based on: - https://github.com/awslabs/aws-apigateway-lambda-authorizer-blueprints/blob/\
master/blueprints/python/api-gateway-authorizer-python.py

Documentation:
-------------
- https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-lambda-authorizer.html
- https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-lambda-authorizer-output.html
"""

@staticmethod
def from_route_arn(
arn: str,
principal_id: str,
context: dict | None = None,
usage_identifier_key: str | None = None,
) -> APIGatewayAuthorizerResponseWebSocket:
parsed_arn = parse_api_gateway_arn(arn, is_websocket_authorizer=True)
return APIGatewayAuthorizerResponseWebSocket(
principal_id,
parsed_arn.region,
parsed_arn.aws_account_id,
parsed_arn.api_id,
parsed_arn.stage,
context,
usage_identifier_key,
)

# Note: we need ignore[override] because we are removing the http_method field
@override
def _add_route(self, effect: str, resource: str, conditions: list[dict] | None = None): # type: ignore[override]
"""Adds a route to the internal lists of allowed or denied routes. Each object in
the internal list contains a resource ARN and a condition statement. The condition
statement can be null."""
resource_arn = APIGatewayRouteArn(
region=self.region,
aws_account_id=self.aws_account_id,
api_id=self.api_id,
stage=self.stage,
http_method=None,
resource=resource,
partition=self.partition,
is_websocket_authorizer=True,
).arn

route = {"resourceArn": resource_arn, "conditions": conditions}

if effect.lower() == "allow":
self._allow_routes.append(route)
else: # deny
self._deny_routes.append(route)

@override
def allow_all_routes(self):
"""Adds a '*' allow to the policy to authorize access to all methods of an API"""
self._add_route(effect="Allow", resource="*")

@override
def deny_all_routes(self):
"""Adds a '*' allow to the policy to deny access to all methods of an API"""

self._add_route(effect="Deny", resource="*")

# Note: we need ignore[override] because we are removing the http_method field
@override
def allow_route(self, resource: str, conditions: list[dict] | None = None): # type: ignore[override]
"""
Add an API Gateway Websocket method to the list of allowed methods for the policy.

This method adds an API Gateway Websocket method Resource path) to the list of
allowed methods for the policy. It optionally includes conditions for the policy statement.

Parameters
----------
resource : str
The API Gateway resource path to allow.
conditions : list[dict] | None, optional
A list of condition dictionaries to apply to the policy statement.
Default is None.

Notes
-----
For more information on AWS policy conditions, see:
https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition

Example
--------
>>> policy = APIGatewayAuthorizerResponseWebSocket(...)
>>> policy.allow_route("/api/users", [{"StringEquals": {"aws:RequestTag/Environment": "Production"}}])
"""
self._add_route(effect="Allow", resource=resource, conditions=conditions)

# Note: we need ignore[override] because we are removing the http_method field
@override
def deny_route(self, resource: str, conditions: list[dict] | None = None): # type: ignore[override]
"""
Add an API Gateway Websocket method to the list of allowed methods for the policy.

This method adds an API Gateway Websocket method Resource path) to the list of
denied methods for the policy. It optionally includes conditions for the policy statement.

Parameters
----------
resource : str
The API Gateway resource path to allow.
conditions : list[dict] | None, optional
A list of condition dictionaries to apply to the policy statement.
Default is None.

Notes
-----
For more information on AWS policy conditions, see:
https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition

Example
--------
>>> policy = APIGatewayAuthorizerResponseWebSocket(...)
>>> policy.deny_route("/api/users", [{"StringEquals": {"aws:RequestTag/Environment": "Production"}}])
"""
self._add_route(effect="Deny", resource=resource, conditions=conditions)
10 changes: 8 additions & 2 deletions docs/utilities/data_classes.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,18 @@ It is used for [API Gateway Rest API Lambda Authorizer payload](https://docs.aws

Use **`APIGatewayAuthorizerRequestEvent`** for type `REQUEST` and **`APIGatewayAuthorizerTokenEvent`** for type `TOKEN`.

=== "app.py"
=== "Rest APIs"

```python hl_lines="2-4 8"
```python hl_lines="2-4 8 18"
--8<-- "examples/event_sources/src/apigw_authorizer_request.py"
```

=== "WebSocket APIs"

```python hl_lines="2-4 8 18"
--8<-- "examples/event_sources/src/apigw_authorizer_request_websocket.py"
```

=== "API Gateway Authorizer Request Example Event"

```json hl_lines="3 11"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from aws_lambda_powertools.utilities.data_classes import event_source
from aws_lambda_powertools.utilities.data_classes.api_gateway_authorizer_event import (
APIGatewayAuthorizerRequestEvent,
APIGatewayAuthorizerResponseWebSocket,
)


@event_source(data_class=APIGatewayAuthorizerRequestEvent)
def lambda_handler(event: APIGatewayAuthorizerRequestEvent, context):
# Simple auth check (replace with your actual auth logic)
is_authorized = event.headers.get("HeaderAuth1") == "headerValue1"

if not is_authorized:
return {"principalId": "", "policyDocument": {"Version": "2012-10-17", "Statement": []}}

arn = event.parsed_arn

policy = APIGatewayAuthorizerResponseWebSocket(
principal_id="user",
context={"user": "example"},
region=arn.region,
aws_account_id=arn.aws_account_id,
api_id=arn.api_id,
stage=arn.stage,
)

policy.allow_all_routes()

return policy.asdict()
81 changes: 81 additions & 0 deletions tests/events/apiGatewayAuthorizerWebSocketEvent.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
{
"type":"REQUEST",
"methodArn":"arn:aws:execute-api:us-east-1:533568316194:c5jwxq709g/production/$connect",
"headers":{
"Authorization":"Leo",
"Connection":"upgrade",
"content-length":"0",
"Host":"c5jwxq709g.execute-api.us-east-1.amazonaws.com",
"Sec-WebSocket-Extensions":"permessage-deflate; client_max_window_bits",
"Sec-WebSocket-Version":"13",
"Upgrade":"websocket",
"X-Amzn-Trace-Id":"Root=1-6797b6d3-64f9c928577f3ac56f5368ce",
"X-Forwarded-For":"93.108.161.96",
"X-Forwarded-Port":"443",
"X-Forwarded-Proto":"https"
},
"multiValueHeaders":{
"Authorization":[
"Leo"
],
"Connection":[
"upgrade"
],
"content-length":[
"0"
],
"Host":[
"c5jwxq709g.execute-api.us-east-1.amazonaws.com"
],
"Sec-WebSocket-Extensions":[
"permessage-deflate; client_max_window_bits"
],
"Sec-WebSocket-Key":[
"CYZZrfNgEcgzzzwL44qytQ=="
],
"Sec-WebSocket-Version":[
"13"
],
"Upgrade":[
"websocket"
],
"X-Amzn-Trace-Id":[
"Root=1-6797b6d3-64f9c928577f3ac56f5368ce"
],
"X-Forwarded-For":[
"93.108.161.96"
],
"X-Forwarded-Port":[
"443"
],
"X-Forwarded-Proto":[
"https"
]
},
"queryStringParameters":{

},
"multiValueQueryStringParameters":{

},
"stageVariables":{

},
"requestContext":{
"routeKey":"$connect",
"eventType":"CONNECT",
"extendedRequestId":"FDmBIG3EoAMEqYA=",
"requestTime":"27/Jan/2025:16:39:47 +0000",
"messageDirection":"IN",
"stage":"production",
"connectedAt":1737995987617,
"requestTimeEpoch":1737995987617,
"identity":{
"sourceIp":"93.108.161.96"
},
"requestId":"FDmBIG3EoAMEqYA=",
"domainName":"c5jwxq709g.execute-api.us-east-1.amazonaws.com",
"connectionId":"FDmBIeapIAMCIQg=",
"apiId":"c5jwxq709g"
}
}
Loading
Loading