Skip to content

Commit 08f97bf

Browse files
Initial tests
1 parent c6e32ed commit 08f97bf

File tree

2 files changed

+31
-38
lines changed

2 files changed

+31
-38
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
"""Advanced feature flags utility"""
2+
23
from .aws_auth import AuthProvider, AWSServicePrefix, AWSSigV4Auth, JWTAuth
34

4-
__all__ = [
5-
"AuthProvider",
6-
"AWSServicePrefix",
7-
"AWSSigV4Auth",
8-
"JWTAuth"
9-
]
5+
__all__ = ["AuthProvider", "AWSServicePrefix", "AWSSigV4Auth", "JWTAuth"]

aws_lambda_powertools/utilities/auth/aws_auth.py

+29-32
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1+
from __future__ import annotations
12

2-
from typing import Optional
3-
from enum import Enum
4-
from botocore.auth import SigV4Auth
5-
from botocore.awsrequest import AWSRequest
6-
from botocore.credentials import Credentials
7-
import botocore.session
8-
from abc import ABC, abstractmethod
9-
3+
import base64
104
import json
5+
from enum import Enum
6+
from typing import Optional
117

8+
import botocore.session
129
import urllib3
13-
import base64
10+
from botocore.auth import SigV4Auth
11+
from botocore.awsrequest import AWSRequest
12+
from botocore.credentials import Credentials, ReadOnlyCredentials
1413

1514

1615
def _authorization_header(client_id: str, client_secret: str) -> str:
@@ -28,6 +27,7 @@ def _authorization_header(client_id: str, client_secret: str) -> str:
2827
encoded_auth_string = base64.b64encode(auth_string.encode("utf-8")).decode("utf-8")
2928
return f"Basic {encoded_auth_string}"
3029

30+
3131
def _get_token(response: dict) -> str:
3232
"""
3333
Gets the token from the response
@@ -45,7 +45,8 @@ def _get_token(response: dict) -> str:
4545
else:
4646
raise Exception("Unable to get token from response")
4747

48-
def _request_access_token(auth_endpoint: str, body: dict, headers: dict) -> dict:
48+
49+
def _request_access_token(auth_endpoint: str, body: dict, headers: dict) -> str:
4950
"""
5051
Gets the token from the Auth0 authentication endpoint
5152
@@ -71,14 +72,10 @@ def _request_access_token(auth_endpoint: str, body: dict, headers: dict) -> dict
7172
response = http.request("POST", auth_endpoint, headers=headers, body=json_body)
7273
response = response.json()
7374
return _get_token(response)
74-
except urllib3.exceptions.RequestError as error:
75+
except (urllib3.exceptions.RequestError, urllib3.exceptions.HTTPError) as error:
7576
# If there is an error with the request, handle it here
76-
raise error
77-
except urllib3.exceptions.HTTPError as error:
78-
raise error
79-
80-
81-
77+
# REVIEW: CREATE A CUSTOM EXCEPTION FOR THIS
78+
raise Exception(error)
8279

8380

8481
class AWSServicePrefix(Enum):
@@ -88,6 +85,7 @@ class AWSServicePrefix(Enum):
8885
URLs:
8986
https://docs.aws.amazon.com/service-authorization/latest/reference/reference_policies_actions-resources-contextkeys.html
9087
"""
88+
9189
LATTICE = "vpc-lattice-svcs"
9290
RESTAPI = "execute-api"
9391
HTTPAPI = "apigateway"
@@ -98,10 +96,12 @@ class AuthProvider(Enum):
9896
"""
9997
Auth Provider - Enumerations of the supported authentication providers
10098
"""
99+
101100
AUTH0 = "auth0"
102101
COGNITO = "cognito"
103102
OKTA = "okta"
104103

104+
105105
class AWSSigV4Auth:
106106
"""
107107
Authenticating Requests (AWS Signature Version 4)
@@ -128,11 +128,10 @@ class AWSSigV4Auth:
128128
>>> auth = AWSSigV4Auth(region="us-east-2", service=AWSServicePrefix.LATTICE, url="https://test-fake-service.vpc-lattice-svcs.us-east-2.on.aws")
129129
"""
130130

131-
132131
def __init__(
133132
self,
134133
url: str,
135-
region: Optional[str],
134+
region: str,
136135
body: Optional[str] = None,
137136
params: Optional[dict] = None,
138137
headers: Optional[dict] = None,
@@ -151,6 +150,8 @@ def __init__(
151150
self.params = params
152151
self.headers = headers
153152

153+
self.credentials: Credentials | ReadOnlyCredentials
154+
154155
if access_key and secret_key and token:
155156
self.access_key = access_key
156157
self.secret_key = secret_key
@@ -178,18 +179,17 @@ def __call__(self):
178179
return self.signed_request
179180

180181

181-
182182
class JWTAuth:
183183

184184
def __init__(
185-
self,
186-
client_id: str,
187-
client_secret: str,
188-
auth_endpoint: str,
189-
provider: Enum = AuthProvider.COGNITO,
190-
audience: Optional[str] = None,
191-
scope: Optional[list] = None
192-
):
185+
self,
186+
client_id: str,
187+
client_secret: str,
188+
auth_endpoint: str,
189+
provider: Enum = AuthProvider.COGNITO,
190+
audience: Optional[str] = None,
191+
scope: Optional[list] = None,
192+
):
193193

194194
self.client_id = client_id
195195
self.client_secret = client_secret
@@ -230,7 +230,4 @@ def __init__(
230230
if scope:
231231
self.body["scope"] = " ".join(self.scope)
232232

233-
234-
response = _request_access_token(auth_endpoint=self.auth_endpoint, body=self.body, headers=self.headers)
235-
236-
233+
# response = _request_access_token(auth_endpoint=self.auth_endpoint, body=self.body, headers=self.headers) # noqa ERA001

0 commit comments

Comments
 (0)