Skip to content

Commit 9517bdd

Browse files
authored
feature: Add presigned URLs for interactive apps (#4086)
1 parent 300cd17 commit 9517bdd

File tree

7 files changed

+629
-136
lines changed

7 files changed

+629
-136
lines changed

src/sagemaker/estimator.py

+71
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
validate_source_code_input_against_pipeline_variables,
6666
)
6767
from sagemaker.inputs import TrainingInput, FileSystemInput
68+
from sagemaker.interactive_apps import SupportedInteractiveAppTypes
69+
from sagemaker.interactive_apps.tensorboard import TensorBoardApp
6870
from sagemaker.instance_group import InstanceGroup
6971
from sagemaker.utils import instance_supports_kms
7072
from sagemaker.job import _Job
@@ -750,6 +752,8 @@ def __init__(
750752
# Internal flag
751753
self._is_output_path_set_from_default_bucket_and_prefix = False
752754

755+
self.tensorboard_app = TensorBoardApp(region=self.sagemaker_session.boto_region_name)
756+
753757
@abstractmethod
754758
def training_image_uri(self):
755759
"""Return the Docker image to use for training.
@@ -2256,6 +2260,73 @@ def update_profiler(
22562260

22572261
_TrainingJob.update(self, profiler_rule_configs, profiler_config_request_dict)
22582262

2263+
def get_app_url(
2264+
self,
2265+
app_type,
2266+
open_in_default_web_browser=True,
2267+
create_presigned_domain_url=False,
2268+
domain_id=None,
2269+
user_profile_name=None,
2270+
optional_create_presigned_url_kwargs=None,
2271+
):
2272+
"""Generate a URL to help access the specified app hosted in Amazon SageMaker Studio.
2273+
2274+
Args:
2275+
app_type (str or SupportedInteractiveAppTypes): Required. The app type available in
2276+
SageMaker Studio to return a URL to.
2277+
open_in_default_web_browser (bool): Optional. When True, the URL will attempt to be
2278+
opened in the environment's default web browser. Otherwise, the resulting URL will
2279+
be returned by this function.
2280+
Default: ``True``
2281+
create_presigned_domain_url (bool): Optional. Determines whether a presigned domain URL
2282+
should be generated instead of an unsigned URL. This only applies when called from
2283+
outside of a SageMaker Studio environment. If this is set to True inside of a
2284+
SageMaker Studio environment, it will be ignored.
2285+
Default: ``False``
2286+
domain_id (str): Optional. The AWS Studio domain that the resulting app will use. If
2287+
code is executing in a Studio environment and this was not supplied, this will be
2288+
automatically detected. If not supplied and running in a non-Studio environment, it
2289+
is up to the derived class on how to handle that, but in general, a redirect to a
2290+
landing page can be expected.
2291+
Default: ``None``
2292+
user_profile_name (str): Optional. The AWS Studio user profile that the resulting app
2293+
will use. If code is executing in a Studio environment and this was not supplied,
2294+
this will be automatically detected. If not supplied and running in a
2295+
non-Studio environment, it is up to the derived class on how to handle that, but in
2296+
general, a redirect to a landing page can be expected.
2297+
Default: ``None``
2298+
optional_create_presigned_url_kwargs (dict): Optional. This parameter
2299+
should be passed when a user outside of Studio wants a presigned URL to the
2300+
TensorBoard application and wants to modify the optional parameters of the
2301+
create_presigned_domain_url call.
2302+
Default: ``None``
2303+
Returns:
2304+
str: A URL for the requested app in SageMaker Studio.
2305+
"""
2306+
url = None
2307+
2308+
# Get app_type in lower str format
2309+
if isinstance(app_type, SupportedInteractiveAppTypes):
2310+
app_type = app_type.name
2311+
app_type = app_type.lower()
2312+
2313+
if app_type == SupportedInteractiveAppTypes.TENSORBOARD.name.lower():
2314+
training_job_name = None
2315+
if self._current_job_name:
2316+
training_job_name = self._current_job_name
2317+
url = self.tensorboard_app.get_app_url(
2318+
training_job_name=training_job_name,
2319+
open_in_default_web_browser=open_in_default_web_browser,
2320+
create_presigned_domain_url=create_presigned_domain_url,
2321+
domain_id=domain_id,
2322+
user_profile_name=user_profile_name,
2323+
optional_create_presigned_url_kwargs=optional_create_presigned_url_kwargs,
2324+
)
2325+
else:
2326+
raise ValueError(f"{app_type} does not support URL retrieval.")
2327+
2328+
return url
2329+
22592330

22602331
class _TrainingJob(_Job):
22612332
"""Placeholder docstring"""

src/sagemaker/interactive_apps/__init__.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,21 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Classes for using debugger and profiler with Amazon SageMaker."""
13+
"""Classes for starting/accessing apps hosted on Amazon SageMaker Studio."""
14+
1415
from __future__ import absolute_import
1516

17+
from enum import Enum
18+
1619
from sagemaker.interactive_apps.tensorboard import ( # noqa: F401
1720
TensorBoardApp,
1821
)
1922
from sagemaker.interactive_apps.detail_profiler_app import ( # noqa: F401
2023
DetailProfilerApp,
2124
)
25+
26+
27+
class SupportedInteractiveAppTypes(Enum):
28+
"""SupportedInteractiveAppTypes indicates which apps are supported."""
29+
30+
TENSORBOARD = 1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""A base class for starting/accessing apps hosted on Amazon SageMaker Studio"""
14+
15+
from __future__ import absolute_import
16+
17+
import abc
18+
import base64
19+
import json
20+
import logging
21+
import os
22+
import re
23+
import webbrowser
24+
25+
from typing import Optional
26+
import boto3
27+
from sagemaker.session import Session, NOTEBOOK_METADATA_FILE
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
class BaseInteractiveApp(abc.ABC):
33+
"""BaseInteractiveApp is a base class for creating/accessing apps hosted on SageMaker."""
34+
35+
def __init__(
36+
self,
37+
region: Optional[str] = None,
38+
):
39+
"""Initialize a BaseInteractiveApp object.
40+
41+
Args:
42+
region (str): Optional. The AWS Region, e.g. us-east-1. If not specified,
43+
one is created using the default AWS configuration chain.
44+
Default: ``None``
45+
"""
46+
if isinstance(region, str):
47+
self.region = region
48+
else:
49+
try:
50+
self.region = Session().boto_region_name
51+
except ValueError:
52+
raise ValueError(
53+
"Failed to get the Region information from the default config. Please either "
54+
"pass your Region manually as an input argument or set up the local AWS"
55+
" configuration."
56+
)
57+
58+
self._sagemaker_client = boto3.client("sagemaker", region_name=self.region)
59+
# Used to store domain and user profile info retrieved from Studio environment.
60+
self._domain_id = None
61+
self._user_profile_name = None
62+
self._get_domain_and_user()
63+
64+
def __str__(self):
65+
"""Return str(self)."""
66+
return f"{type(self).__name__}(region={self.region})"
67+
68+
def __repr__(self):
69+
"""Return repr(self)."""
70+
return self.__str__()
71+
72+
def _get_domain_and_user(self):
73+
"""Get and validate studio domain id and user profile from studio environment."""
74+
if not self._is_in_studio():
75+
return
76+
77+
try:
78+
with open(NOTEBOOK_METADATA_FILE, "rb") as metadata_file:
79+
metadata = json.loads(metadata_file.read())
80+
if not self._validate_domain_id(
81+
metadata.get("DomainId")
82+
) or not self._validate_user_profile_name(metadata.get("UserProfileName")):
83+
logger.warning(
84+
"NOTEBOOK_METADATA_FILE detected but failed to get valid domain and user"
85+
" from it."
86+
)
87+
return
88+
self._domain_id = metadata.get("DomainId")
89+
self._user_profile_name = metadata.get("UserProfileName")
90+
except OSError as err:
91+
logger.warning("Could not load Studio metadata due to unexpected error. %s", err)
92+
93+
def _get_presigned_url(
94+
self,
95+
create_presigned_url_kwargs: dict,
96+
redirect: Optional[str] = None,
97+
state: Optional[str] = None,
98+
):
99+
"""Generate a presigned URL to access a user's domain / user profile.
100+
101+
Optional state and redirect parameters can be used to to have presigned URL automatically
102+
redirect to a specific app and provide modifying data.
103+
104+
Args:
105+
create_presigned_url_kwargs (dict): Required. This dictionary should include the
106+
parameters that will be used when calling create_presigned_domain_url via the boto3
107+
client. At a minimum, this should include the "DomainId" and "UserProfileName"
108+
parameters as defined by create_presigned_domain_url's documentation.
109+
Default: ``None``
110+
redirect (str): Optional. This value will be appended to the resulting presigned URL
111+
in the format "&redirect=<redirect parameter>". This is used to automatically
112+
redirect the user into a specific Studio app.
113+
Default: ``None``
114+
state (str): Optional. This value will be appended to the resulting presigned URL
115+
in the format "&state=<state parameter base64 encoded>". This is used to
116+
automatically apply a state to the given app. Should be used in conjuction with
117+
the redirect parameter.
118+
Default: ``None``
119+
120+
Returns:
121+
str: A presigned URL.
122+
"""
123+
response = self._sagemaker_client.create_presigned_domain_url(**create_presigned_url_kwargs)
124+
if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
125+
url = response["AuthorizedUrl"]
126+
else:
127+
raise ValueError(
128+
"An invalid status code was returned when creating a presigned URL."
129+
f" See response for more: {response}"
130+
)
131+
132+
if redirect:
133+
url += f"&redirect={redirect}"
134+
135+
if state:
136+
url += f"&state={base64.b64encode(bytes(state, 'utf-8')).decode('utf-8')}"
137+
138+
logger.warning(
139+
"A presigned domain URL was generated. This is sensitive and should not be shared with"
140+
" others."
141+
)
142+
143+
return url
144+
145+
def _is_in_studio(self):
146+
"""Check to see if NOTEBOOK_METADATA_FILE exists to verify Studio environment."""
147+
return os.path.isfile(NOTEBOOK_METADATA_FILE)
148+
149+
def _open_url_in_web_browser(self, url: str):
150+
"""Open a URL in the default web browser.
151+
152+
Args:
153+
url (str): The URL to open.
154+
"""
155+
webbrowser.open(url)
156+
157+
def _validate_domain_id(self, domain_id: Optional[str] = None):
158+
"""Validate domain id format.
159+
160+
Args:
161+
domain_id (str): Optional. The domain ID to validate. If one is not supplied,
162+
self._domain_id will be used instead.
163+
Default: ``None``
164+
165+
Returns:
166+
bool: Whether the supplied domain ID is valid.
167+
"""
168+
if domain_id is None:
169+
domain_id = self._domain_id
170+
if domain_id is None or len(domain_id) > 63:
171+
return False
172+
return True
173+
174+
def _validate_job_name(self, job_name: str):
175+
"""Validate training job name format.
176+
177+
Args:
178+
job_name (str): The job name to validate.
179+
180+
Returns:
181+
bool: Whether the supplied job name is valid.
182+
"""
183+
job_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}"
184+
if not re.fullmatch(job_name_regex, job_name):
185+
raise ValueError(
186+
f"Invalid job name. Job name must match regular expression {job_name_regex}"
187+
)
188+
189+
def _validate_user_profile_name(self, user_profile_name: Optional[str] = None):
190+
"""Validate user profile name format.
191+
192+
Args:
193+
user_profile_name (str): Optional. The user profile name to validate. If one is not
194+
supplied, self._user_profile_name will be used instead.
195+
Default: ``None``
196+
197+
Returns:
198+
bool: Whether the supplied user profile name is valid.
199+
"""
200+
if user_profile_name is None:
201+
user_profile_name = self._user_profile_name
202+
user_profile_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}"
203+
if user_profile_name is None or not re.fullmatch(
204+
user_profile_name_regex, user_profile_name
205+
):
206+
return False
207+
return True
208+
209+
def _validate_domain_and_user(self):
210+
"""Helper function to consolidate validation calls."""
211+
return self._validate_domain_id() and self._validate_user_profile_name()
212+
213+
@abc.abstractmethod
214+
def get_app_url(self):
215+
"""Abstract method to generate a URL to help access the application in Studio.
216+
217+
Classes that inherit from BaseInteractiveApp should implement and override with what
218+
parameters are needed for its specific use case.
219+
"""

0 commit comments

Comments
 (0)