File tree 3 files changed +17
-0
lines changed
tests/unit/sagemaker/jumpstart
3 files changed +17
-0
lines changed Original file line number Diff line number Diff line change 122
122
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
123
123
124
124
SUPPORTED_JUMPSTART_SCOPES = set (scope .value for scope in JumpStartScriptScope )
125
+
126
+ ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
Original file line number Diff line number Diff line change 13
13
"""This module contains utilities related to SageMaker JumpStart."""
14
14
from __future__ import absolute_import
15
15
import logging
16
+ import os
16
17
from typing import Dict , List , Optional
17
18
from urllib .parse import urlparse
18
19
from packaging .version import Version
@@ -60,6 +61,12 @@ def get_jumpstart_content_bucket(region: str) -> str:
60
61
Raises:
61
62
RuntimeError: If JumpStart is not launched in ``region``.
62
63
"""
64
+
65
+ if (
66
+ constants .ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os .environ
67
+ and len (os .environ [constants .ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE ]) > 0
68
+ ):
69
+ return os .environ [constants .ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE ]
63
70
try :
64
71
return constants .JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT [region ].content_bucket
65
72
except KeyError :
Original file line number Diff line number Diff line change 11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
+ import os
14
15
from mock .mock import Mock , patch
15
16
import pytest
16
17
import random
17
18
from sagemaker .jumpstart import utils
18
19
from sagemaker .jumpstart .constants import (
20
+ ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE ,
19
21
JUMPSTART_BUCKET_NAME_SET ,
20
22
JUMPSTART_REGION_NAME_SET ,
21
23
JumpStartScriptScope ,
@@ -40,6 +42,12 @@ def test_get_jumpstart_content_bucket():
40
42
utils .get_jumpstart_content_bucket (bad_region )
41
43
42
44
45
+ def test_get_jumpstart_content_bucket_override ():
46
+ with patch .dict (os .environ , {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE : "some-val" }):
47
+ random_region = "random_region"
48
+ assert "some-val" == utils .get_jumpstart_content_bucket (random_region )
49
+
50
+
43
51
def test_get_jumpstart_launched_regions_message ():
44
52
45
53
with patch ("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET" , {}):
You can’t perform that action at this time.
0 commit comments