Skip to content

Commit 327f8f1

Browse files
committed
feat: override jumpstart content bucket
1 parent c437191 commit 327f8f1

File tree

3 files changed

+17
-0
lines changed

3 files changed

+17
-0
lines changed

src/sagemaker/jumpstart/constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,5 @@
122122
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
123123

124124
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
125+
126+
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"

src/sagemaker/jumpstart/utils.py

+7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This module contains utilities related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
1515
import logging
16+
import os
1617
from typing import Dict, List, Optional
1718
from urllib.parse import urlparse
1819
from packaging.version import Version
@@ -60,6 +61,12 @@ def get_jumpstart_content_bucket(region: str) -> str:
6061
Raises:
6162
RuntimeError: If JumpStart is not launched in ``region``.
6263
"""
64+
65+
if (
66+
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ
67+
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
68+
):
69+
return os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]
6370
try:
6471
return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket
6572
except KeyError:

tests/unit/sagemaker/jumpstart/test_utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
import os
1415
from mock.mock import Mock, patch
1516
import pytest
1617
import random
1718
from sagemaker.jumpstart import utils
1819
from sagemaker.jumpstart.constants import (
20+
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE,
1921
JUMPSTART_BUCKET_NAME_SET,
2022
JUMPSTART_REGION_NAME_SET,
2123
JumpStartScriptScope,
@@ -40,6 +42,12 @@ def test_get_jumpstart_content_bucket():
4042
utils.get_jumpstart_content_bucket(bad_region)
4143

4244

45+
def test_get_jumpstart_content_bucket_override():
46+
with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}):
47+
random_region = "random_region"
48+
assert "some-val" == utils.get_jumpstart_content_bucket(random_region)
49+
50+
4351
def test_get_jumpstart_launched_regions_message():
4452

4553
with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}):

0 commit comments

Comments
 (0)