Skip to content

Commit c098e98

Browse files
evakraviknikure
authored andcommitted
fix: Gated content bucket env var override (aws#1280)
1 parent 5db63f5 commit c098e98

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

src/sagemaker/jumpstart/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@
168168
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
169169

170170
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
171+
ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE"
171172
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE"
172173
ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE"
173174
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE = (

src/sagemaker/jumpstart/utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ def get_jumpstart_gated_content_bucket(
8181

8282
gated_bucket_to_return: Optional[str] = None
8383
if (
84-
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ
85-
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
84+
constants.ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE in os.environ
85+
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE]) > 0
8686
):
8787
gated_bucket_to_return = os.environ[
88-
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE
88+
constants.ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE
8989
]
90-
info_logs.append(f"Using JumpStart private bucket override: '{gated_bucket_to_return}'")
90+
info_logs.append(f"Using JumpStart gated bucket override: '{gated_bucket_to_return}'")
9191
else:
9292
try:
9393
gated_bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[

tests/unit/sagemaker/jumpstart/test_utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.jumpstart.constants import (
2121
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2222
ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE,
23+
ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE,
2324
JUMPSTART_DEFAULT_REGION_NAME,
2425
JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET,
2526
JUMPSTART_REGION_NAME_SET,
@@ -78,12 +79,12 @@ def test_get_jumpstart_gated_content_bucket_no_args():
7879

7980

8081
def test_get_jumpstart_gated_content_bucket_override():
81-
with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}):
82+
with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE: "some-val"}):
8283
with patch("logging.Logger.info") as mocked_info_log:
8384
random_region = "random_region"
8485
assert "some-val" == utils.get_jumpstart_gated_content_bucket(random_region)
8586
mocked_info_log.assert_called_once_with(
86-
"Using JumpStart private bucket override: 'some-val'"
87+
"Using JumpStart gated bucket override: 'some-val'"
8788
)
8889

8990

0 commit comments

Comments
 (0)