Skip to content

Commit 8d95d30

Browse files
committed
fix: jumpstart works without region in aws config
1 parent 42743aa commit 8d95d30

File tree

5 files changed

+16
-9
lines changed

5 files changed

+16
-9
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@
112112
}
113113
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}
114114

115-
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name
115+
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
116116

117117
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
118118

tests/integ/sagemaker/jumpstart/retrieve_uri/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
JUMPSTART_TAG,
2828
)
2929

30+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
31+
3032

3133
def _setup():
3234
print("Setting up...")
@@ -41,7 +43,9 @@ def _teardown():
4143
test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
4244

4345
sagemaker_client = boto3.client(
44-
"sagemaker", config=Config(retries={"max_attempts": 10, "mode": "standard"})
46+
"sagemaker",
47+
config=Config(retries={"max_attempts": 10, "mode": "standard"}),
48+
region_name=JUMPSTART_DEFAULT_REGION_NAME,
4549
)
4650

4751
search_endpoints_result = sagemaker_client.search(

tests/integ/sagemaker/jumpstart/retrieve_uri/inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
2424
from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import (
2525
get_test_artifact_bucket,
26+
get_sm_session,
2627
)
27-
from sagemaker.session import Session
2828

2929
from sagemaker.utils import repack_model
3030
from sagemaker.model import (
@@ -59,7 +59,7 @@ def __init__(
5959
self.region = region
6060
self.config = boto_config
6161
self.base_name = base_name
62-
self.execution_role = execution_role or Session().get_caller_identity_arn()
62+
self.execution_role = execution_role or get_sm_session().get_caller_identity_arn()
6363
self.account_id = boto3.client("sts").get_caller_identity()["Account"]
6464
self.image_uri = image_uri
6565
self.script_uri = script_uri
@@ -102,7 +102,7 @@ def package_artifacts(self):
102102
dependencies=None,
103103
model_uri=self.model_uri,
104104
repacked_model_uri=repacked_model_uri,
105-
sagemaker_session=Session(),
105+
sagemaker_session=get_sm_session(),
106106
kms_key=None,
107107
)
108108

tests/integ/sagemaker/jumpstart/retrieve_uri/training.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121
from tests.integ.sagemaker.jumpstart.retrieve_uri.utils import (
2222
get_full_hyperparameters,
2323
get_test_artifact_bucket,
24+
get_sm_session,
2425
)
2526
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
2627

2728
from tests.integ.sagemaker.jumpstart.retrieve_uri.constants import (
2829
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
2930
)
3031

31-
from sagemaker.session import Session
32-
3332

3433
class TrainingJobLauncher:
3534
def __init__(
@@ -53,7 +52,7 @@ def __init__(
5352
self.region = region
5453
self.config = boto_config
5554
self.base_name = base_name
56-
self.execution_role = execution_role or Session().get_caller_identity_arn()
55+
self.execution_role = execution_role or get_sm_session().get_caller_identity_arn()
5756
self.image_uri = image_uri
5857
self.script_uri = script_uri
5958
self.model_uri = model_uri

tests/integ/sagemaker/jumpstart/retrieve_uri/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,12 @@ def get_training_dataset_for_model_and_version(model_id: str, version: str) -> d
6767
return TRAINING_DATASET_MODEL_DICT[(model_id, version)]
6868

6969

70+
def get_sm_session() -> Session:
71+
return Session(boto_session=boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME))
72+
73+
7074
def get_test_artifact_bucket() -> str:
71-
bucket_name = Session().default_bucket()
75+
bucket_name = get_sm_session().default_bucket()
7276
return bucket_name
7377

7478

0 commit comments

Comments
 (0)