Skip to content

Commit 55713ce

Browse files
authored
feat(event_handler): add support for OpenAPI security schemes (#4103)
1 parent 1e7b3ab commit 55713ce

File tree

15 files changed

+862
-37
lines changed

15 files changed

+862
-37
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+125-24
Large diffs are not rendered by default.

aws_lambda_powertools/event_handler/bedrock_agent.py

+15
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def get( # type: ignore[override]
102102
include_in_schema: bool = True,
103103
middlewares: Optional[List[Callable[..., Any]]] = None,
104104
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
105+
security = None
106+
105107
return super(BedrockAgentResolver, self).get(
106108
rule,
107109
cors,
@@ -114,6 +116,7 @@ def get( # type: ignore[override]
114116
tags,
115117
operation_id,
116118
include_in_schema,
119+
security,
117120
middlewares,
118121
)
119122

@@ -134,6 +137,8 @@ def post( # type: ignore[override]
134137
include_in_schema: bool = True,
135138
middlewares: Optional[List[Callable[..., Any]]] = None,
136139
):
140+
security = None
141+
137142
return super().post(
138143
rule,
139144
cors,
@@ -146,6 +151,7 @@ def post( # type: ignore[override]
146151
tags,
147152
operation_id,
148153
include_in_schema,
154+
security,
149155
middlewares,
150156
)
151157

@@ -166,6 +172,8 @@ def put( # type: ignore[override]
166172
include_in_schema: bool = True,
167173
middlewares: Optional[List[Callable[..., Any]]] = None,
168174
):
175+
security = None
176+
169177
return super().put(
170178
rule,
171179
cors,
@@ -178,6 +186,7 @@ def put( # type: ignore[override]
178186
tags,
179187
operation_id,
180188
include_in_schema,
189+
security,
181190
middlewares,
182191
)
183192

@@ -198,6 +207,8 @@ def patch( # type: ignore[override]
198207
include_in_schema: bool = True,
199208
middlewares: Optional[List[Callable]] = None,
200209
):
210+
security = None
211+
201212
return super().patch(
202213
rule,
203214
cors,
@@ -210,6 +221,7 @@ def patch( # type: ignore[override]
210221
tags,
211222
operation_id,
212223
include_in_schema,
224+
security,
213225
middlewares,
214226
)
215227

@@ -230,6 +242,8 @@ def delete( # type: ignore[override]
230242
include_in_schema: bool = True,
231243
middlewares: Optional[List[Callable[..., Any]]] = None,
232244
):
245+
security = None
246+
233247
return super().delete(
234248
rule,
235249
cors,
@@ -242,6 +256,7 @@ def delete( # type: ignore[override]
242256
tags,
243257
operation_id,
244258
include_in_schema,
259+
security,
245260
middlewares,
246261
)
247262

aws_lambda_powertools/event_handler/openapi/models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -441,12 +441,13 @@ class SecurityBase(BaseModel):
441441
description: Optional[str] = None
442442

443443
if PYDANTIC_V2:
444-
model_config = {"extra": "allow"}
444+
model_config = {"extra": "allow", "populate_by_name": True}
445445

446446
else:
447447

448448
class Config:
449449
extra = "allow"
450+
allow_population_by_field_name = True
450451

451452

452453
class APIKeyIn(Enum):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import (
2+
generate_swagger_html,
3+
)
4+
from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import (
5+
OAuth2Config,
6+
generate_oauth2_redirect_html,
7+
)
8+
9+
__all__ = [
10+
"generate_swagger_html",
11+
"generate_oauth2_redirect_html",
12+
"OAuth2Config",
13+
]

aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
1-
def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: str, swagger_base_url: str) -> str:
1+
from typing import Optional
2+
3+
from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import OAuth2Config
4+
5+
6+
def generate_swagger_html(
7+
spec: str,
8+
path: str,
9+
swagger_js: str,
10+
swagger_css: str,
11+
swagger_base_url: str,
12+
oauth2_config: Optional[OAuth2Config],
13+
) -> str:
214
"""
315
Generate Swagger UI HTML page
416
@@ -8,10 +20,14 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
820
The OpenAPI spec
921
path: str
1022
The path to the Swagger documentation
11-
js_url: str
12-
The URL to the Swagger UI JavaScript file
13-
css_url: str
14-
The URL to the Swagger UI CSS file
23+
swagger_js: str
24+
Swagger UI JavaScript source code or URL
25+
swagger_css: str
26+
Swagger UI CSS source code or URL
27+
swagger_base_url: str
28+
The base URL for Swagger UI
29+
oauth2_config: OAuth2Config, optional
30+
The OAuth2 configuration.
1531
"""
1632

1733
# If Swagger base URL is present, generate HTML content with linked CSS and JavaScript files
@@ -23,6 +39,11 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
2339
swagger_css_content = f"<style>{swagger_css}</style>"
2440
swagger_js_content = f"<script>{swagger_js}</script>"
2541

42+
# Prepare oauth2 config
43+
oauth2_content = (
44+
f"ui.initOAuth({oauth2_config.json(exclude_none=True, exclude_unset=True)});" if oauth2_config else ""
45+
)
46+
2647
return f"""
2748
<!DOCTYPE html>
2849
<html>
@@ -45,6 +66,9 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
4566
{swagger_js_content}
4667
4768
<script>
69+
var currentUrl = new URL(window.location.href);
70+
var baseUrl = currentUrl.protocol + "//" + currentUrl.host + currentUrl.pathname;
71+
4872
var swaggerUIOptions = {{
4973
dom_id: "#swagger-ui",
5074
docExpansion: "list",
@@ -60,11 +84,14 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
6084
],
6185
plugins: [
6286
SwaggerUIBundle.plugins.DownloadUrl
63-
]
87+
],
88+
withCredentials: true,
89+
oauth2RedirectUrl: baseUrl + "?format=oauth2-redirect",
6490
}}
6591
6692
var ui = SwaggerUIBundle(swaggerUIOptions)
6793
ui.specActions.updateUrl('{path}?format=json');
94+
{oauth2_content}
6895
</script>
6996
</html>
7097
""".strip()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# ruff: noqa: E501
2+
import warnings
3+
from typing import Dict, Optional, Sequence
4+
5+
from pydantic import BaseModel, Field, validator
6+
7+
from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2
8+
from aws_lambda_powertools.shared.functions import powertools_dev_is_set
9+
10+
11+
# Based on https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/
12+
class OAuth2Config(BaseModel):
13+
"""
14+
OAuth2 configuration for Swagger UI
15+
"""
16+
17+
# The client ID for the OAuth2 application
18+
clientId: Optional[str] = Field(alias="client_id", default=None)
19+
20+
# The client secret for the OAuth2 application. This is sensitive information and requires the explicit presence
21+
# of the POWERTOOLS_DEV environment variable.
22+
clientSecret: Optional[str] = Field(alias="client_secret", default=None)
23+
24+
# The realm in which the OAuth2 application is registered. Optional.
25+
realm: Optional[str] = Field(default=None)
26+
27+
# The name of the OAuth2 application
28+
appName: str = Field(alias="app_name")
29+
30+
# The scopes that the OAuth2 application requires. Defaults to an empty list.
31+
scopes: Sequence[str] = Field(default=[])
32+
33+
# Additional query string parameters to be included in the OAuth2 request. Defaults to an empty dictionary.
34+
additionalQueryStringParams: Dict[str, str] = Field(alias="additional_query_string_params", default={})
35+
36+
# Whether to use basic authentication with the access code grant type. Defaults to False.
37+
useBasicAuthenticationWithAccessCodeGrant: bool = Field(
38+
alias="use_basic_authentication_with_access_code_grant",
39+
default=False,
40+
)
41+
42+
# Whether to use PKCE with the authorization code grant type. Defaults to False.
43+
usePkceWithAuthorizationCodeGrant: bool = Field(alias="use_pkce_with_authorization_code_grant", default=False)
44+
45+
if PYDANTIC_V2:
46+
model_config = {"extra": "allow"}
47+
else:
48+
49+
class Config:
50+
extra = "allow"
51+
allow_population_by_field_name = True
52+
53+
@validator("clientSecret", always=True)
54+
def client_secret_only_on_dev(cls, v: Optional[str]) -> Optional[str]:
55+
if not v:
56+
return None
57+
58+
if not powertools_dev_is_set():
59+
raise ValueError(
60+
"cannot use client_secret without POWERTOOLS_DEV mode. See "
61+
"https://docs.powertools.aws.dev/lambda/python/latest/#optimizing-for-non-production-environments",
62+
)
63+
else:
64+
warnings.warn(
65+
"OAuth2Config is using client_secret and POWERTOOLS_DEV is set. This reveals sensitive information. "
66+
"DO NOT USE THIS OUTSIDE LOCAL DEVELOPMENT",
67+
stacklevel=2,
68+
)
69+
return v
70+
71+
72+
def generate_oauth2_redirect_html() -> str:
73+
"""
74+
Generates the HTML content for the OAuth2 redirect page.
75+
76+
Source: https://github.com/swagger-api/swagger-ui/blob/master/dist/oauth2-redirect.html
77+
"""
78+
return """
79+
<!doctype html>
80+
<html lang="en-US">
81+
<head>
82+
<title>Swagger UI: OAuth2 Redirect</title>
83+
</head>
84+
<body>
85+
<script>
86+
'use strict';
87+
function run () {
88+
var oauth2 = window.opener.swaggerUIRedirectOauth2;
89+
var sentState = oauth2.state;
90+
var redirectUrl = oauth2.redirectUrl;
91+
var isValid, qp, arr;
92+
93+
if (/code|token|error/.test(window.location.hash)) {
94+
qp = window.location.hash.substring(1).replace('?', '&');
95+
} else {
96+
qp = location.search.substring(1);
97+
}
98+
99+
arr = qp.split("&");
100+
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';});
101+
qp = qp ? JSON.parse('{' + arr.join() + '}',
102+
function (key, value) {
103+
return key === "" ? value : decodeURIComponent(value);
104+
}
105+
) : {};
106+
107+
isValid = qp.state === sentState;
108+
109+
if ((
110+
oauth2.auth.schema.get("flow") === "accessCode" ||
111+
oauth2.auth.schema.get("flow") === "authorizationCode" ||
112+
oauth2.auth.schema.get("flow") === "authorization_code"
113+
) && !oauth2.auth.code) {
114+
if (!isValid) {
115+
oauth2.errCb({
116+
authId: oauth2.auth.name,
117+
source: "auth",
118+
level: "warning",
119+
message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server."
120+
});
121+
}
122+
123+
if (qp.code) {
124+
delete oauth2.state;
125+
oauth2.auth.code = qp.code;
126+
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
127+
} else {
128+
let oauthErrorMsg;
129+
if (qp.error) {
130+
oauthErrorMsg = "["+qp.error+"]: " +
131+
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
132+
(qp.error_uri ? "More info: "+qp.error_uri : "");
133+
}
134+
135+
oauth2.errCb({
136+
authId: oauth2.auth.name,
137+
source: "auth",
138+
level: "error",
139+
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server."
140+
});
141+
}
142+
} else {
143+
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
144+
}
145+
window.close();
146+
}
147+
148+
if (document.readyState !== 'loading') {
149+
run();
150+
} else {
151+
document.addEventListener('DOMContentLoaded', function () {
152+
run();
153+
});
154+
}
155+
</script>
156+
</body>
157+
</html>
158+
""".strip()

0 commit comments

Comments
 (0)