File tree 4 files changed +26
-3
lines changed
tests/unit/sagemaker/jumpstart
4 files changed +26
-3
lines changed Original file line number Diff line number Diff line change @@ -746,6 +746,7 @@ see `Model <https://sagemaker.readthedocs.io/en/stable/api/inference/model.html
746
746
.. code :: python
747
747
748
748
from sagemaker.model import Model
749
+ from sagemaker.predictor import Predictor
749
750
from sagemaker.session import Session
750
751
751
752
# Create the SageMaker model instance
@@ -755,6 +756,7 @@ see `Model <https://sagemaker.readthedocs.io/en/stable/api/inference/model.html
755
756
source_dir = script_uri,
756
757
entry_point = " inference.py" ,
757
758
role = Session().get_caller_identity_arn(),
759
+ predictor_cls = Predictor,
758
760
)
759
761
760
762
Save the output from deploying the model to a variable named
@@ -766,12 +768,9 @@ Deployment may take about 5 minutes.
766
768
767
769
.. code :: python
768
770
769
- from sagemaker.predictor import Predictor
770
-
771
771
predictor = model.deploy(
772
772
initial_instance_count = instance_count,
773
773
instance_type = instance_type,
774
- predictor_cls = Predictor
775
774
)
776
775
777
776
Because ``catboost `` and ``lightgbm `` rely on the PyTorch Deep Learning Containers
Original file line number Diff line number Diff line change 122
122
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
123
123
124
124
SUPPORTED_JUMPSTART_SCOPES = set (scope .value for scope in JumpStartScriptScope )
125
+
126
+ ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE"
Original file line number Diff line number Diff line change 13
13
"""This module contains utilities related to SageMaker JumpStart."""
14
14
from __future__ import absolute_import
15
15
import logging
16
+ import os
16
17
from typing import Dict , List , Optional
17
18
from urllib .parse import urlparse
18
19
from packaging .version import Version
@@ -60,6 +61,14 @@ def get_jumpstart_content_bucket(region: str) -> str:
60
61
Raises:
61
62
RuntimeError: If JumpStart is not launched in ``region``.
62
63
"""
64
+
65
+ if (
66
+ constants .ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os .environ
67
+ and len (os .environ [constants .ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE ]) > 0
68
+ ):
69
+ bucket_override = os .environ [constants .ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE ]
70
+ LOGGER .info ("Using JumpStart bucket override: '%s'" , bucket_override )
71
+ return bucket_override
63
72
try :
64
73
return constants .JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT [region ].content_bucket
65
74
except KeyError :
Original file line number Diff line number Diff line change 11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
+ import os
14
15
from mock .mock import Mock , patch
15
16
import pytest
16
17
import random
17
18
from sagemaker .jumpstart import utils
18
19
from sagemaker .jumpstart .constants import (
20
+ ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE ,
19
21
JUMPSTART_BUCKET_NAME_SET ,
20
22
JUMPSTART_REGION_NAME_SET ,
21
23
JumpStartScriptScope ,
@@ -40,6 +42,17 @@ def test_get_jumpstart_content_bucket():
40
42
utils .get_jumpstart_content_bucket (bad_region )
41
43
42
44
45
+ def test_get_jumpstart_content_bucket_override ():
46
+ with patch .dict (os .environ , {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE : "some-val" }):
47
+ with patch ("logging.Logger.info" ) as mocked_info_log :
48
+ random_region = "random_region"
49
+ assert "some-val" == utils .get_jumpstart_content_bucket (random_region )
50
+ mocked_info_log .assert_called_once_with (
51
+ "Using JumpStart bucket override: '%s'" ,
52
+ "some-val" ,
53
+ )
54
+
55
+
43
56
def test_get_jumpstart_launched_regions_message ():
44
57
45
58
with patch ("sagemaker.jumpstart.constants.JUMPSTART_REGION_NAME_SET" , {}):
You can’t perform that action at this time.
0 commit comments