Skip to content

Commit df1c9dd

Browse files
edwardpsEdward Sun
authored andcommitted
feat: Partner App Auth Provider for SDK support (#1548)
Co-authored-by: Edward Sun <[email protected]>
1 parent b09fbb8 commit df1c9dd

File tree

7 files changed

+532
-0
lines changed

7 files changed

+532
-0
lines changed

src/sagemaker/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,6 @@
7474
)
7575

7676
from sagemaker.debugger import ProfilerConfig, Profiler # noqa: F401
77+
from sagemaker.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401
7778

7879
__version__ = importlib_metadata.version("sagemaker")

src/sagemaker/partner_app/__init__.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""__init__ file for sagemaker.partner_app.auth_provider"""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401
+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
"""The SageMaker partner application SDK auth module"""
15+
from __future__ import absolute_import
16+
17+
import os
18+
import re
19+
from typing import Dict, Tuple
20+
21+
import boto3
22+
from botocore.auth import SigV4Auth
23+
from botocore.credentials import Credentials
24+
from requests.auth import AuthBase
25+
from requests.models import PreparedRequest
26+
from sagemaker.partner_app.auth_utils import PartnerAppAuthUtils
27+
28+
SERVICE_NAME = "sagemaker"
29+
AWS_PARTNER_APP_ARN_REGEX = r"arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:partner-app\/.*"
30+
31+
32+
class RequestsAuth(AuthBase):
33+
"""Requests authentication class for SigV4 header generation.
34+
35+
This class is used to generate the SigV4 header and add it to the request headers.
36+
"""
37+
38+
def __init__(self, sigv4: SigV4Auth, app_arn: str):
39+
"""Initialize the RequestsAuth class.
40+
41+
Args:
42+
sigv4 (SigV4Auth): SigV4Auth object
43+
app_arn (str): Application ARN
44+
"""
45+
self.sigv4 = sigv4
46+
self.app_arn = app_arn
47+
48+
def __call__(self, request: PreparedRequest) -> PreparedRequest:
49+
"""Callback function to generate the SigV4 header and add it to the request headers.
50+
51+
Args:
52+
request (PreparedRequest): PreparedRequest object
53+
54+
Returns:
55+
PreparedRequest: PreparedRequest object with the SigV4 header added
56+
"""
57+
url, signed_headers = PartnerAppAuthUtils.get_signed_request(
58+
sigv4=self.sigv4,
59+
app_arn=self.app_arn,
60+
url=request.url,
61+
method=request.method,
62+
headers=request.headers,
63+
body=request.body,
64+
)
65+
request.url = url
66+
request.headers.update(signed_headers)
67+
68+
return request
69+
70+
71+
class PartnerAppAuthProvider:
72+
"""The SageMaker partner application SDK auth provider class"""
73+
74+
def __init__(self, credentials: Credentials = None):
75+
"""Initialize the PartnerAppAuthProvider class.
76+
77+
Args:
78+
credentials (Credentials, optional): AWS credentials. Defaults to None.
79+
Raises:
80+
ValueError: If the AWS_PARTNER_APP_ARN environment variable is not set or is invalid.
81+
"""
82+
self.app_arn = os.getenv("AWS_PARTNER_APP_ARN")
83+
if self.app_arn is None:
84+
raise ValueError("Must specify the AWS_PARTNER_APP_ARN environment variable")
85+
86+
app_arn_regex_match = re.search(AWS_PARTNER_APP_ARN_REGEX, self.app_arn)
87+
if app_arn_regex_match is None:
88+
raise ValueError("Must specify a valid AWS_PARTNER_APP_ARN environment variable")
89+
90+
split_arn = self.app_arn.split(":")
91+
self.region = split_arn[3]
92+
93+
self.credentials = (
94+
credentials if credentials is not None else boto3.Session().get_credentials()
95+
)
96+
self.sigv4 = SigV4Auth(self.credentials, SERVICE_NAME, self.region)
97+
98+
def get_signed_request(
99+
self, url: str, method: str, headers: dict, body: object
100+
) -> Tuple[str, Dict[str, str]]:
101+
"""Generate the SigV4 header and add it to the request headers.
102+
103+
Args:
104+
url (str): Request URL
105+
method (str): HTTP method
106+
headers (dict): Request headers
107+
body (object): Request body
108+
109+
Returns:
110+
tuple: (url, headers)
111+
"""
112+
return PartnerAppAuthUtils.get_signed_request(
113+
sigv4=self.sigv4,
114+
app_arn=self.app_arn,
115+
url=url,
116+
method=method,
117+
headers=headers,
118+
body=body,
119+
)
120+
121+
def get_auth(self) -> RequestsAuth:
122+
"""Returns the callback class (RequestsAuth) used for generating the SigV4 header.
123+
124+
Returns:
125+
RequestsAuth: Callback Object which will calculate the header just before
126+
request submission.
127+
"""
128+
129+
return RequestsAuth(self.sigv4, os.environ["AWS_PARTNER_APP_ARN"])
+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
"""Partner App Auth Utils Module"""
15+
16+
from __future__ import absolute_import
17+
18+
from hashlib import sha256
19+
import functools
20+
from typing import Tuple, Dict
21+
22+
from botocore.auth import SigV4Auth
23+
from botocore.awsrequest import AWSRequest
24+
25+
HEADER_CONNECTION = "Connection"
26+
HEADER_X_AMZ_TARGET = "X-Amz-Target"
27+
HEADER_AUTHORIZATION = "Authorization"
28+
HEADER_MLAPP_SM_APP_SERVER_ARN = "X-Mlapp-Sm-App-Server-Arn"
29+
HEADER_PARTNER_APP_AUTHORIZATION = "X-Amz-Partner-App-Authorization"
30+
HEADER_X_AMZ_CONTENT_SHA_256 = "X-Amz-Content-SHA256"
31+
CALL_PARTNER_APP_API_ACTION = "SageMaker.CallPartnerAppApi"
32+
33+
PAYLOAD_BUFFER = 1024 * 1024
34+
EMPTY_SHA256_HASH = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
35+
UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD"
36+
37+
38+
class PartnerAppAuthUtils:
39+
"""Partner App Auth Utils Class"""
40+
41+
@staticmethod
42+
def get_signed_request(
43+
sigv4: SigV4Auth, app_arn: str, url: str, method: str, headers: dict, body: object
44+
) -> Tuple[str, Dict[str, str]]:
45+
"""Generate the SigV4 header and add it to the request headers.
46+
47+
Args:
48+
sigv4 (SigV4Auth): SigV4Auth object
49+
app_arn (str): Application ARN
50+
url (str): Request URL
51+
method (str): HTTP method
52+
headers (dict): Request headers
53+
body (object): Request body
54+
Returns:
55+
tuple: (url, headers)
56+
"""
57+
# Move API key to X-Amz-Partner-App-Authorization
58+
if HEADER_AUTHORIZATION in headers:
59+
headers[HEADER_PARTNER_APP_AUTHORIZATION] = headers[HEADER_AUTHORIZATION]
60+
61+
# App Arn
62+
headers[HEADER_MLAPP_SM_APP_SERVER_ARN] = app_arn
63+
64+
# IAM Action
65+
headers[HEADER_X_AMZ_TARGET] = CALL_PARTNER_APP_API_ACTION
66+
67+
# Body
68+
headers[HEADER_X_AMZ_CONTENT_SHA_256] = PartnerAppAuthUtils.get_body_header(body)
69+
70+
# Connection header is excluded from server-side signature calculation
71+
connection_header = headers[HEADER_CONNECTION] if HEADER_CONNECTION in headers else None
72+
73+
if HEADER_CONNECTION in headers:
74+
del headers[HEADER_CONNECTION]
75+
76+
# Spaces are encoded as %20
77+
# TODO - confirm the motivation
78+
if method in ("GET", "DEL"):
79+
url = url.replace("+", "%20")
80+
81+
# Calculate SigV4 header
82+
aws_request = AWSRequest(
83+
method=method,
84+
url=url,
85+
headers=headers,
86+
data=body,
87+
)
88+
sigv4.add_auth(aws_request)
89+
90+
# Reassemble headers
91+
final_headers = dict(aws_request.headers.items())
92+
if connection_header is not None:
93+
final_headers[HEADER_CONNECTION] = connection_header
94+
95+
return (url, final_headers)
96+
97+
@staticmethod
98+
def get_body_header(body: object):
99+
"""Calculate the body header for the SigV4 header.
100+
101+
Args:
102+
body (object): Request body
103+
"""
104+
if body and hasattr(body, "seek"):
105+
position = body.tell()
106+
read_chunksize = functools.partial(body.read, PAYLOAD_BUFFER)
107+
checksum = sha256()
108+
for chunk in iter(read_chunksize, b""):
109+
checksum.update(chunk)
110+
hex_checksum = checksum.hexdigest()
111+
body.seek(position)
112+
return hex_checksum
113+
114+
if body and not isinstance(body, bytes):
115+
# Body is of a class we don't recognize, so don't sign the payload
116+
return UNSIGNED_PAYLOAD
117+
118+
if body:
119+
# The request serialization has ensured that
120+
# request.body is a bytes() type.
121+
return sha256(body).hexdigest()
122+
123+
# Body is None
124+
return EMPTY_SHA256_HASH

tests/unit/sagemaker/partner_app/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)