Skip to content

Commit 8215dd2

Browse files
evakravinavinsonimufaddal-rohawalaqidewenwhenHappyAmazonian
authored
feat: override jumpstart content bucket (#2901)
Co-authored-by: Navin Soni <[email protected]> Co-authored-by: Mufaddal Rohawala <[email protected]> Co-authored-by: qidewenwhen <[email protected]> Co-authored-by: HappyAmazonian <[email protected]>
1 parent 09d3fd1 commit 8215dd2

File tree

4 files changed

+26
-3
lines changed

4 files changed

+26
-3
lines changed

doc/overview.rst

+2-3
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,7 @@ see `Model <https://sagemaker.readthedocs.io/en/stable/api/inference/model.html
746746
.. code:: python
747747
748748
from sagemaker.model import Model
749+
from sagemaker.predictor import Predictor
749750
from sagemaker.session import Session
750751
751752
# Create the SageMaker model instance
@@ -755,6 +756,7 @@ see `Model <https://sagemaker.readthedocs.io/en/stable/api/inference/model.html
755756
   source_dir=script_uri,
756757
   entry_point="inference.py",
757758
   role=Session().get_caller_identity_arn(),
759+
   predictor_cls=Predictor,
758760
)
759761
760762
Save the output from deploying the model to a variable named
@@ -766,12 +768,9 @@ Deployment may take about 5 minutes.
766768

767769
.. code:: python
768770
769-
from sagemaker.predictor import Predictor
770-
771771
predictor = model.deploy(
772772
   initial_instance_count=instance_count,
773773
   instance_type=instance_type,
774-
   predictor_cls=Predictor
775774
)
776775
777776
Because ``catboost`` and ``lightgbm`` rely on the PyTorch Deep Learning Containers

src/sagemaker/jumpstart/constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,5 @@
122122
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
123123

124124
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
125+
126+
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"

src/sagemaker/jumpstart/utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This module contains utilities related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
import logging
16+
import os
1617
from typing import Dict, List, Optional
1718
from urllib.parse import urlparse
1819
from packaging.version import Version
@@ -60,6 +61,14 @@ def get_jumpstart_content_bucket(region: str) -> str:
6061
Raises:
6162
RuntimeError: If JumpStart is not launched in ``region``.
6263
"""
64+
65+
if (
66+
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ
67+
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
68+
):
69+
bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]
70+
LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override)
71+
return bucket_override
6372
try:
6473
return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket
6574
except KeyError:

tests/unit/sagemaker/jumpstart/test_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import os
1415
from mock.mock import Mock, patch
1516
import pytest
1617
import random
1718
from sagemaker.jumpstart import utils
1819
from sagemaker.jumpstart.constants import (
20+
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE,
1921
JUMPSTART_BUCKET_NAME_SET,
2022
JUMPSTART_REGION_NAME_SET,
2123
JumpStartScriptScope,
@@ -40,6 +42,17 @@ def test_get_jumpstart_content_bucket():
4042
utils.get_jumpstart_content_bucket(bad_region)
4143

4244

45+
def test_get_jumpstart_content_bucket_override():
46+
with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}):
47+
with patch("logging.Logger.info") as mocked_info_log:
48+
random_region = "random_region"
49+
assert "some-val" == utils.get_jumpstart_content_bucket(random_region)
50+
mocked_info_log.assert_called_once_with(
51+
"Using JumpStart bucket override: '%s'",
52+
"some-val",
53+
)
54+
55+
4356
def test_get_jumpstart_launched_regions_message():
4457

4558
with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}):

0 commit comments

Comments
 (0)