Skip to content

feat: override jumpstart content bucket #2901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 19, 2022
5 changes: 2 additions & 3 deletions doc/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ see `Model <https://sagemaker.readthedocs.io/en/stable/api/inference/model.html
.. code:: python

from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.session import Session

# Create the SageMaker model instance
Expand All @@ -755,6 +756,7 @@ see `Model <https://sagemaker.readthedocs.io/en/stable/api/inference/model.html
   source_dir=script_uri,
   entry_point="inference.py",
   role=Session().get_caller_identity_arn(),
   predictor_cls=Predictor,
)

Save the output from deploying the model to a variable named
Expand All @@ -766,12 +768,9 @@ Deployment may take about 5 minutes.

.. code:: python

from sagemaker.predictor import Predictor

predictor = model.deploy(
   initial_instance_count=instance_count,
   instance_type=instance_type,
   predictor_cls=Predictor
)

Because ``catboost`` and ``lightgbm`` rely on the PyTorch Deep Learning Containers
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,5 @@
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"

SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)

ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
9 changes: 9 additions & 0 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""This module contains utilities related to SageMaker JumpStart."""
from __future__ import absolute_import
import logging
import os
from typing import Dict, List, Optional
from urllib.parse import urlparse
from packaging.version import Version
Expand Down Expand Up @@ -60,6 +61,14 @@ def get_jumpstart_content_bucket(region: str) -> str:
Raises:
RuntimeError: If JumpStart is not launched in ``region``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this used downstream and what will happen for multiple values?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is used for getting the bucket for all JumpStart resources.

"""

if (
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
):
bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]
LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override)
return bucket_override
try:
return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket
except KeyError:
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import os
from mock.mock import Mock, patch
import pytest
import random
from sagemaker.jumpstart import utils
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE,
JUMPSTART_BUCKET_NAME_SET,
JUMPSTART_REGION_NAME_SET,
JumpStartScriptScope,
Expand All @@ -40,6 +42,17 @@ def test_get_jumpstart_content_bucket():
utils.get_jumpstart_content_bucket(bad_region)


def test_get_jumpstart_content_bucket_override():
    """Check that the content-bucket env variable override is returned and logged.

    With ``ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE`` set in the environment,
    ``utils.get_jumpstart_content_bucket`` is expected to return the override value
    for an arbitrary region name and to emit exactly one info log naming the bucket.
    """
    override_env = {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}

    # patch.dict restores os.environ on exit, so the override does not leak into other tests.
    with patch.dict(os.environ, override_env), patch("logging.Logger.info") as mocked_info_log:
        bucket = utils.get_jumpstart_content_bucket("random_region")

        assert "some-val" == bucket
        mocked_info_log.assert_called_once_with(
            "Using JumpStart bucket override: '%s'",
            "some-val",
        )


def test_get_jumpstart_launched_regions_message():

with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}):
Expand Down