diff --git a/doc/overview.rst b/doc/overview.rst index 39f5f6ecae..df320e3b47 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -746,6 +746,7 @@ see `Model str: Raises: RuntimeError: If JumpStart is not launched in ``region``. """ + + if ( + constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ + and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0 + ): + bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE] + LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override) + return bucket_override try: return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket except KeyError: diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index fe494eb459..04eddced08 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -11,11 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +import os from mock.mock import Mock, patch import pytest import random from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, JUMPSTART_BUCKET_NAME_SET, JUMPSTART_REGION_NAME_SET, JumpStartScriptScope, @@ -40,6 +42,17 @@ def test_get_jumpstart_content_bucket(): utils.get_jumpstart_content_bucket(bad_region) +def test_get_jumpstart_content_bucket_override(): + with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}): + with patch("logging.Logger.info") as mocked_info_log: + random_region = "random_region" + assert "some-val" == utils.get_jumpstart_content_bucket(random_region) + mocked_info_log.assert_called_once_with( + "Using JumpStart bucket override: '%s'", + "some-val", + ) + + def test_get_jumpstart_launched_regions_message(): with patch("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET", {}):