Skip to content

Commit 226be67

Browse files
evakravibenieric
authored and
root
committed
fix: sagemaker session region not being used (aws#4469)
* fix: sagemaker session region not being used * chore: add unit tests * fix: remove all JUMPSTART_DEFAULT_REGION_NAME default arguments * chore: use get_region_fallback throughout * chore: remove unnecessary if statement * chore: remove unnecessary if statement (2) --------- Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 2754a13 commit 226be67

File tree

40 files changed

+437
-181
lines changed

40 files changed

+437
-181
lines changed

src/sagemaker/jumpstart/artifacts/environment_variables.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from typing import Callable, Dict, Optional, Set
1616
from sagemaker.jumpstart.constants import (
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
18-
JUMPSTART_DEFAULT_REGION_NAME,
1918
JUMPSTART_LOGGER,
2019
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
2120
)
@@ -24,6 +23,7 @@
2423
)
2524
from sagemaker.jumpstart.utils import (
2625
get_jumpstart_gated_content_bucket,
26+
get_region_fallback,
2727
verify_model_region_and_return_specs,
2828
)
2929
from sagemaker.session import Session
@@ -72,8 +72,9 @@ def _retrieve_default_environment_variables(
7272
dict: the inference environment variables to use for the model.
7373
"""
7474

75-
if region is None:
76-
region = JUMPSTART_DEFAULT_REGION_NAME
75+
region = region or get_region_fallback(
76+
sagemaker_session=sagemaker_session,
77+
)
7778

7879
model_specs = verify_model_region_and_return_specs(
7980
model_id=model_id,
@@ -198,8 +199,9 @@ def _retrieve_gated_model_uri_env_var_value(
198199
ValueError: If the model specs specified are invalid.
199200
"""
200201

201-
if region is None:
202-
region = JUMPSTART_DEFAULT_REGION_NAME
202+
region = region or get_region_fallback(
203+
sagemaker_session=sagemaker_session,
204+
)
203205

204206
model_specs = verify_model_region_and_return_specs(
205207
model_id=model_id,

src/sagemaker/jumpstart/artifacts/hyperparameters.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
from typing import Dict, Optional
1616
from sagemaker.jumpstart.constants import (
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
18-
JUMPSTART_DEFAULT_REGION_NAME,
1918
)
2019
from sagemaker.jumpstart.enums import (
2120
JumpStartScriptScope,
2221
VariableScope,
2322
)
2423
from sagemaker.jumpstart.utils import (
24+
get_region_fallback,
2525
verify_model_region_and_return_specs,
2626
)
2727
from sagemaker.session import Session
@@ -70,8 +70,9 @@ def _retrieve_default_hyperparameters(
7070
dict: the hyperparameters to use for the model.
7171
"""
7272

73-
if region is None:
74-
region = JUMPSTART_DEFAULT_REGION_NAME
73+
region = region or get_region_fallback(
74+
sagemaker_session=sagemaker_session,
75+
)
7576

7677
model_specs = verify_model_region_and_return_specs(
7778
model_id=model_id,

src/sagemaker/jumpstart/artifacts/image_uris.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from sagemaker import image_uris
1818
from sagemaker.jumpstart.constants import (
1919
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
20-
JUMPSTART_DEFAULT_REGION_NAME,
2120
)
2221
from sagemaker.jumpstart.enums import (
2322
JumpStartScriptScope,
2423
ModelFramework,
2524
)
2625
from sagemaker.jumpstart.utils import (
26+
get_region_fallback,
2727
verify_model_region_and_return_specs,
2828
)
2929
from sagemaker.session import Session
@@ -104,8 +104,9 @@ def _retrieve_image_uri(
104104
known security vulnerabilities.
105105
DeprecatedJumpStartModelError: If the version of the model is deprecated.
106106
"""
107-
if region is None:
108-
region = JUMPSTART_DEFAULT_REGION_NAME
107+
region = region or get_region_fallback(
108+
sagemaker_session=sagemaker_session,
109+
)
109110

110111
model_specs = verify_model_region_and_return_specs(
111112
model_id=model_id,

src/sagemaker/jumpstart/artifacts/incremental_training.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
from typing import Optional
1616
from sagemaker.jumpstart.constants import (
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
18-
JUMPSTART_DEFAULT_REGION_NAME,
1918
)
2019
from sagemaker.jumpstart.enums import (
2120
JumpStartScriptScope,
2221
)
2322
from sagemaker.jumpstart.utils import (
23+
get_region_fallback,
2424
verify_model_region_and_return_specs,
2525
)
2626
from sagemaker.session import Session
@@ -58,8 +58,9 @@ def _model_supports_incremental_training(
5858
bool: the support status for incremental training.
5959
"""
6060

61-
if region is None:
62-
region = JUMPSTART_DEFAULT_REGION_NAME
61+
region = region or get_region_fallback(
62+
sagemaker_session=sagemaker_session,
63+
)
6364

6465
model_specs = verify_model_region_and_return_specs(
6566
model_id=model_id,

src/sagemaker/jumpstart/artifacts/instance_types.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818
from sagemaker.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG
1919
from sagemaker.jumpstart.constants import (
2020
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
21-
JUMPSTART_DEFAULT_REGION_NAME,
2221
)
2322
from sagemaker.jumpstart.enums import (
2423
JumpStartScriptScope,
2524
JumpStartModelType,
2625
)
2726
from sagemaker.jumpstart.utils import (
27+
get_region_fallback,
2828
verify_model_region_and_return_specs,
2929
)
3030
from sagemaker.session import Session
@@ -76,8 +76,9 @@ def _retrieve_default_instance_type(
7676
specified region due to lack of supported computing instances.
7777
"""
7878

79-
if region is None:
80-
region = JUMPSTART_DEFAULT_REGION_NAME
79+
region = region or get_region_fallback(
80+
sagemaker_session=sagemaker_session,
81+
)
8182

8283
model_specs = verify_model_region_and_return_specs(
8384
model_id=model_id,
@@ -163,8 +164,9 @@ def _retrieve_instance_types(
163164
specified region due to lack of supported computing instances.
164165
"""
165166

166-
if region is None:
167-
region = JUMPSTART_DEFAULT_REGION_NAME
167+
region = region or get_region_fallback(
168+
sagemaker_session=sagemaker_session,
169+
)
168170

169171
model_specs = verify_model_region_and_return_specs(
170172
model_id=model_id,

src/sagemaker/jumpstart/artifacts/kwargs.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818
from sagemaker.utils import volume_size_supported
1919
from sagemaker.jumpstart.constants import (
2020
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
21-
JUMPSTART_DEFAULT_REGION_NAME,
2221
)
2322
from sagemaker.jumpstart.enums import (
2423
JumpStartScriptScope,
2524
JumpStartModelType,
2625
)
2726
from sagemaker.jumpstart.utils import (
27+
get_region_fallback,
2828
verify_model_region_and_return_specs,
2929
)
3030

@@ -62,8 +62,9 @@ def _retrieve_model_init_kwargs(
6262
dict: the kwargs to use for the use case.
6363
"""
6464

65-
if region is None:
66-
region = JUMPSTART_DEFAULT_REGION_NAME
65+
region = region or get_region_fallback(
66+
sagemaker_session=sagemaker_session,
67+
)
6768

6869
model_specs = verify_model_region_and_return_specs(
6970
model_id=model_id,
@@ -121,8 +122,9 @@ def _retrieve_model_deploy_kwargs(
121122
dict: the kwargs to use for the use case.
122123
"""
123124

124-
if region is None:
125-
region = JUMPSTART_DEFAULT_REGION_NAME
125+
region = region or get_region_fallback(
126+
sagemaker_session=sagemaker_session,
127+
)
126128

127129
model_specs = verify_model_region_and_return_specs(
128130
model_id=model_id,
@@ -176,8 +178,9 @@ def _retrieve_estimator_init_kwargs(
176178
dict: the kwargs to use for the use case.
177179
"""
178180

179-
if region is None:
180-
region = JUMPSTART_DEFAULT_REGION_NAME
181+
region = region or get_region_fallback(
182+
sagemaker_session=sagemaker_session,
183+
)
181184

182185
model_specs = verify_model_region_and_return_specs(
183186
model_id=model_id,
@@ -233,8 +236,9 @@ def _retrieve_estimator_fit_kwargs(
233236
dict: the kwargs to use for the use case.
234237
"""
235238

236-
if region is None:
237-
region = JUMPSTART_DEFAULT_REGION_NAME
239+
region = region or get_region_fallback(
240+
sagemaker_session=sagemaker_session,
241+
)
238242

239243
model_specs = verify_model_region_and_return_specs(
240244
model_id=model_id,

src/sagemaker/jumpstart/artifacts/metric_definitions.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
from typing import Dict, List, Optional
1717
from sagemaker.jumpstart.constants import (
1818
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
19-
JUMPSTART_DEFAULT_REGION_NAME,
2019
)
2120
from sagemaker.jumpstart.enums import (
2221
JumpStartScriptScope,
2322
)
2423
from sagemaker.jumpstart.utils import (
24+
get_region_fallback,
2525
verify_model_region_and_return_specs,
2626
)
2727
from sagemaker.session import Session
@@ -62,8 +62,9 @@ def _retrieve_default_training_metric_definitions(
6262
list: the default training metric definitions to use for the model or None.
6363
"""
6464

65-
if region is None:
66-
region = JUMPSTART_DEFAULT_REGION_NAME
65+
region = region or get_region_fallback(
66+
sagemaker_session=sagemaker_session,
67+
)
6768

6869
model_specs = verify_model_region_and_return_specs(
6970
model_id=model_id,

src/sagemaker/jumpstart/artifacts/model_packages.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from typing import Optional
1616
from sagemaker.jumpstart.constants import (
1717
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
18-
JUMPSTART_DEFAULT_REGION_NAME,
1918
)
2019
from sagemaker.jumpstart.utils import (
20+
get_region_fallback,
2121
verify_model_region_and_return_specs,
2222
)
2323
from sagemaker.jumpstart.enums import (
@@ -65,8 +65,9 @@ def _retrieve_model_package_arn(
6565
str: the model package arn to use for the model or None.
6666
"""
6767

68-
if region is None:
69-
region = JUMPSTART_DEFAULT_REGION_NAME
68+
region = region or get_region_fallback(
69+
sagemaker_session=sagemaker_session,
70+
)
7071

7172
model_specs = verify_model_region_and_return_specs(
7273
model_id=model_id,
@@ -149,8 +150,9 @@ def _retrieve_model_package_model_artifact_s3_uri(
149150

150151
if scope == JumpStartScriptScope.TRAINING:
151152

152-
if region is None:
153-
region = JUMPSTART_DEFAULT_REGION_NAME
153+
region = region or get_region_fallback(
154+
sagemaker_session=sagemaker_session,
155+
)
154156

155157
model_specs = verify_model_region_and_return_specs(
156158
model_id=model_id,

src/sagemaker/jumpstart/artifacts/model_uris.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818
from sagemaker.jumpstart.constants import (
1919
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2020
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
21-
JUMPSTART_DEFAULT_REGION_NAME,
2221
)
2322
from sagemaker.jumpstart.enums import (
2423
JumpStartScriptScope,
2524
)
2625
from sagemaker.jumpstart.utils import (
2726
get_jumpstart_content_bucket,
2827
get_jumpstart_gated_content_bucket,
28+
get_region_fallback,
2929
verify_model_region_and_return_specs,
3030
)
3131
from sagemaker.session import Session
@@ -129,8 +129,9 @@ def _retrieve_model_uri(
129129
known security vulnerabilities.
130130
DeprecatedJumpStartModelError: If the version of the model is deprecated.
131131
"""
132-
if region is None:
133-
region = JUMPSTART_DEFAULT_REGION_NAME
132+
region = region or get_region_fallback(
133+
sagemaker_session=sagemaker_session,
134+
)
134135

135136
model_specs = verify_model_region_and_return_specs(
136137
model_id=model_id,
@@ -206,8 +207,9 @@ def _model_supports_training_model_uri(
206207
bool: the support status for model uri with training.
207208
"""
208209

209-
if region is None:
210-
region = JUMPSTART_DEFAULT_REGION_NAME
210+
region = region or get_region_fallback(
211+
sagemaker_session=sagemaker_session,
212+
)
211213

212214
model_specs = verify_model_region_and_return_specs(
213215
model_id=model_id,

src/sagemaker/jumpstart/artifacts/payloads.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
from typing import Dict, Optional
1717
from sagemaker.jumpstart.constants import (
1818
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
19-
JUMPSTART_DEFAULT_REGION_NAME,
2019
)
2120
from sagemaker.jumpstart.enums import (
2221
JumpStartScriptScope,
2322
JumpStartModelType,
2423
)
2524
from sagemaker.jumpstart.types import JumpStartSerializablePayload
2625
from sagemaker.jumpstart.utils import (
26+
get_region_fallback,
2727
verify_model_region_and_return_specs,
2828
)
2929
from sagemaker.session import Session
@@ -63,8 +63,9 @@ def _retrieve_example_payloads(
6363
to the serializable payload object.
6464
"""
6565

66-
if region is None:
67-
region = JUMPSTART_DEFAULT_REGION_NAME
66+
region = region or get_region_fallback(
67+
sagemaker_session=sagemaker_session,
68+
)
6869

6970
model_specs = verify_model_region_and_return_specs(
7071
model_id=model_id,

0 commit comments

Comments
 (0)