Skip to content

Commit 7279028

Browse files
committed
fix: remove all JUMPSTART_DEFAULT_REGION_NAME default arguments
1 parent 46efb3a commit 7279028

File tree

7 files changed

+134
-91
lines changed

7 files changed

+134
-91
lines changed

src/sagemaker/jumpstart/cache.py

+2-33
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import datetime
1616
from difflib import get_close_matches
1717
import os
18-
from typing import List, Optional, Set, Tuple, Union
18+
from typing import List, Optional, Tuple, Union
1919
import json
2020
import boto3
2121
import botocore
@@ -25,9 +25,7 @@
2525
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
2626
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
2727
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
28-
JUMPSTART_DEFAULT_REGION_NAME,
2928
JUMPSTART_LOGGER,
30-
JUMPSTART_REGION_NAME_SET,
3129
MODEL_ID_LIST_WEB_URL,
3230
)
3331
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
@@ -94,7 +92,7 @@ def __init__(
9492
s3_client (Optional[boto3.client]): s3 client to use. Default: None.
9593
"""
9694

97-
self._region = region or self._get_region_fallback(
95+
self._region = region or utils.get_region_fallback(
9896
s3_bucket_name=s3_bucket_name, s3_client=s3_client
9997
)
10098

@@ -122,35 +120,6 @@ def __init__(
122120
else boto3.client("s3", region_name=self._region)
123121
)
124122

125-
def _get_region_fallback(
126-
self, s3_bucket_name: Optional[str], s3_client: Optional[boto3.client]
127-
) -> str:
128-
"""Returns region to use throughout cache in the absence of one specified in constructor."""
129-
regions_in_s3_bucket_name: Set[str] = {
130-
region
131-
for region in JUMPSTART_REGION_NAME_SET
132-
if s3_bucket_name is not None
133-
if region in s3_bucket_name
134-
}
135-
regions_in_s3_client_endpoint_url: Set[str] = {
136-
region
137-
for region in JUMPSTART_REGION_NAME_SET
138-
if s3_client is not None
139-
if region in s3_client._endpoint.host
140-
}
141-
142-
combined_regions = regions_in_s3_client_endpoint_url.union(regions_in_s3_bucket_name)
143-
144-
if len(combined_regions) > 1:
145-
raise ValueError(
146-
"Unable to resolve a region name from the s3 bucket and client provided."
147-
)
148-
149-
if len(combined_regions) == 0:
150-
return JUMPSTART_DEFAULT_REGION_NAME
151-
152-
return list(combined_regions)[0]
153-
154123
def set_region(self, region: str) -> None:
155124
"""Set region for cache. Clears cache after new region is set."""
156125
if region != self._region:

src/sagemaker/jumpstart/notebook_utils.py

+32-13
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from sagemaker.jumpstart import accessors
2424
from sagemaker.jumpstart.constants import (
2525
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
26-
JUMPSTART_DEFAULT_REGION_NAME,
2726
)
2827
from sagemaker.jumpstart.enums import JumpStartScriptScope
2928
from sagemaker.jumpstart.filters import (
@@ -36,6 +35,7 @@
3635
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
3736
from sagemaker.jumpstart.utils import (
3837
get_jumpstart_content_bucket,
38+
get_region_fallback,
3939
get_sagemaker_version,
4040
verify_model_region_and_return_specs,
4141
)
@@ -143,7 +143,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
143143

144144
def list_jumpstart_tasks( # pylint: disable=redefined-builtin
145145
filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
146-
region: str = JUMPSTART_DEFAULT_REGION_NAME,
146+
region: Optional[str] = None,
147147
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
148148
) -> List[str]:
149149
"""List tasks for JumpStart, and optionally apply filters to result.
@@ -155,11 +155,14 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
155155
(e.g. ``"task == ic"``). If this argument is not supplied, all tasks will be listed.
156156
(Default: Constant(BooleanValues.TRUE)).
157157
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
158-
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
158+
models. (Default: None).
159159
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
160160
use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
161161
"""
162162

163+
region = region or get_region_fallback(
164+
sagemaker_session=sagemaker_session,
165+
)
163166
tasks: Set[str] = set()
164167
for model_id, _ in _generate_jumpstart_model_versions(
165168
filter=filter, region=region, sagemaker_session=sagemaker_session
@@ -171,7 +174,7 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
171174

172175
def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
173176
filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
174-
region: str = JUMPSTART_DEFAULT_REGION_NAME,
177+
region: Optional[str] = None,
175178
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
176179
) -> List[str]:
177180
"""List frameworks for JumpStart, and optionally apply filters to result.
@@ -183,11 +186,14 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
183186
(eg. ``"task == ic"``). If this argument is not supplied, all frameworks will be listed.
184187
(Default: Constant(BooleanValues.TRUE)).
185188
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
186-
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
189+
models. (Default: None).
187190
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
188191
to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
189192
"""
190193

194+
region = region or get_region_fallback(
195+
sagemaker_session=sagemaker_session,
196+
)
191197
frameworks: Set[str] = set()
192198
for model_id, _ in _generate_jumpstart_model_versions(
193199
filter=filter, region=region, sagemaker_session=sagemaker_session
@@ -199,7 +205,7 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
199205

200206
def list_jumpstart_scripts( # pylint: disable=redefined-builtin
201207
filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
202-
region: str = JUMPSTART_DEFAULT_REGION_NAME,
208+
region: Optional[str] = None,
203209
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
204210
) -> List[str]:
205211
"""List scripts for JumpStart, and optionally apply filters to result.
@@ -211,10 +217,13 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
211217
(e.g. ``"task == ic"``). If this argument is not supplied, all scripts will be listed.
212218
(Default: Constant(BooleanValues.TRUE)).
213219
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
214-
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
220+
models. (Default: None).
215221
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
216222
use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
217223
"""
224+
region = region or get_region_fallback(
225+
sagemaker_session=sagemaker_session,
226+
)
218227
if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or (
219228
isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower()
220229
):
@@ -242,7 +251,7 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
242251

243252
def list_jumpstart_models( # pylint: disable=redefined-builtin
244253
filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
245-
region: str = JUMPSTART_DEFAULT_REGION_NAME,
254+
region: Optional[str] = None,
246255
list_incomplete_models: bool = False,
247256
list_old_models: bool = False,
248257
list_versions: bool = False,
@@ -257,7 +266,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
257266
(e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed.
258267
(Default: Constant(BooleanValues.TRUE)).
259268
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
260-
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
269+
models. (Default: None).
261270
list_incomplete_models (bool): Optional. If a model does not contain metadata fields
262271
requested by the filter, and the filter cannot be resolved to a include/not include,
263272
whether the model should be included. By default, these models are omitted from results.
@@ -270,6 +279,9 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
270279
to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
271280
"""
272281

282+
region = region or get_region_fallback(
283+
sagemaker_session=sagemaker_session,
284+
)
273285
model_id_version_dict: Dict[str, List[str]] = dict()
274286
for model_id, version in _generate_jumpstart_model_versions(
275287
filter=filter,
@@ -299,7 +311,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
299311

300312
def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
301313
filter: Union[Operator, str] = Constant(BooleanValues.TRUE),
302-
region: str = JUMPSTART_DEFAULT_REGION_NAME,
314+
region: Optional[str] = None,
303315
list_incomplete_models: bool = False,
304316
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
305317
) -> Generator:
@@ -312,7 +324,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
312324
(e.g. ``"task == ic"``). If this argument is not supplied, all models will be generated.
313325
(Default: Constant(BooleanValues.TRUE)).
314326
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
315-
models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
327+
models. (Default: None).
316328
list_incomplete_models (bool): Optional. If a model does not contain metadata fields
317329
requested by the filter, and the filter cannot be resolved to a include/not include,
318330
whether the model should be included. By default, these models are omitted from
@@ -321,6 +333,10 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
321333
to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
322334
"""
323335

336+
region = region or get_region_fallback(
337+
sagemaker_session=sagemaker_session,
338+
)
339+
324340
models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(
325341
region=region, s3_client=sagemaker_session.s3_client
326342
)
@@ -453,7 +469,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
453469
def get_model_url(
454470
model_id: str,
455471
model_version: str,
456-
region: str = JUMPSTART_DEFAULT_REGION_NAME,
472+
region: Optional[str] = None,
457473
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
458474
) -> str:
459475
"""Retrieve web url describing pretrained model.
@@ -462,11 +478,14 @@ def get_model_url(
462478
model_id (str): The model ID for which to retrieve the url.
463479
model_version (str): The model version for which to retrieve the url.
464480
region (str): Optional. The region from which to retrieve metadata.
465-
(Default: JUMPSTART_DEFAULT_REGION_NAME)
481+
(Default: None)
466482
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
467483
to retrieve the model url.
468484
"""
469485

486+
region = region or get_region_fallback(
487+
sagemaker_session=sagemaker_session,
488+
)
470489
model_specs = verify_model_region_and_return_specs(
471490
region=region,
472491
model_id=model_id,

src/sagemaker/jumpstart/payload_utils.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
from sagemaker.jumpstart.artifacts.payloads import _retrieve_example_payloads
2323
from sagemaker.jumpstart.constants import (
2424
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
25-
JUMPSTART_DEFAULT_REGION_NAME,
2625
)
2726
from sagemaker.jumpstart.enums import MIMEType
2827
from sagemaker.jumpstart.types import JumpStartSerializablePayload
2928
from sagemaker.jumpstart.utils import (
3029
get_jumpstart_content_bucket,
30+
get_region_fallback,
3131
)
3232
from sagemaker.session import Session
3333

@@ -125,12 +125,14 @@ class PayloadSerializer:
125125
def __init__(
126126
self,
127127
bucket: Optional[str] = None,
128-
region: str = JUMPSTART_DEFAULT_REGION_NAME,
128+
region: Optional[str] = None,
129129
s3_client: Optional[boto3.client] = None,
130130
) -> None:
131131
"""Initializes PayloadSerializer object."""
132132
self.bucket = bucket or get_jumpstart_content_bucket()
133-
self.region = region
133+
self.region = region or get_region_fallback(
134+
s3_client=s3_client,
135+
)
134136
self.s3_client = s3_client
135137

136138
def get_bytes_payload_with_s3_references(

src/sagemaker/jumpstart/utils.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515
import logging
1616
import os
17-
from typing import Any, Dict, List, Optional, Tuple, Union
17+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
1818
from urllib.parse import urlparse
1919
import boto3
2020
from packaging.version import Version
@@ -813,3 +813,42 @@ def get_jumpstart_model_id_version_from_resource_arn(
813813
model_version = model_version_from_tag
814814

815815
return model_id, model_version
816+
817+
818+
def get_region_fallback(
819+
s3_bucket_name: Optional[str] = None,
820+
s3_client: Optional[boto3.client] = None,
821+
sagemaker_session: Optional[Session] = None,
822+
) -> str:
823+
"""Returns region to use for JumpStart functionality implicitly via session objects."""
824+
regions_in_s3_bucket_name: Set[str] = {
825+
region
826+
for region in constants.JUMPSTART_REGION_NAME_SET
827+
if s3_bucket_name is not None
828+
if region in s3_bucket_name
829+
}
830+
regions_in_s3_client_endpoint_url: Set[str] = {
831+
region
832+
for region in constants.JUMPSTART_REGION_NAME_SET
833+
if s3_client is not None
834+
if region in s3_client._endpoint.host
835+
}
836+
837+
regions_in_sagemaker_session: Set[str] = {
838+
region
839+
for region in constants.JUMPSTART_REGION_NAME_SET
840+
if sagemaker_session
841+
if region == sagemaker_session.boto_region_name
842+
}
843+
844+
combined_regions = regions_in_s3_client_endpoint_url.union(
845+
regions_in_s3_bucket_name, regions_in_sagemaker_session
846+
)
847+
848+
if len(combined_regions) > 1:
849+
raise ValueError("Unable to resolve a region name from the s3 bucket and client provided.")
850+
851+
if len(combined_regions) == 0:
852+
return constants.JUMPSTART_DEFAULT_REGION_NAME
853+
854+
return list(combined_regions)[0]

src/sagemaker/jumpstart/validators.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from __future__ import absolute_import
1515
from typing import Any, Dict, List, Optional
1616
from sagemaker import session
17-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
1817

1918
from sagemaker.jumpstart.enums import (
2019
HyperparameterValidationMode,
@@ -24,7 +23,7 @@
2423
)
2524
from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError
2625
from sagemaker.jumpstart.types import JumpStartHyperparameter
27-
from sagemaker.jumpstart.utils import verify_model_region_and_return_specs
26+
from sagemaker.jumpstart.utils import get_region_fallback, verify_model_region_and_return_specs
2827

2928

3029
def _validate_hyperparameter(
@@ -168,7 +167,7 @@ def validate_hyperparameters(
168167
model_version: str,
169168
hyperparameters: Dict[str, Any],
170169
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
171-
region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME,
170+
region: Optional[str] = None,
172171
sagemaker_session: Optional[session.Session] = None,
173172
tolerate_vulnerable_model: bool = False,
174173
tolerate_deprecated_model: bool = False,
@@ -184,8 +183,7 @@ def validate_hyperparameters(
184183
to this function will be validated, the missing hyperparameters will be ignored.
185184
If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated.
186185
If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated.
187-
region (str): Region for which to validate hyperparameters. (Default: JumpStart
188-
default region).
186+
region (str): Region for which to validate hyperparameters. (Default: None).
189187
sagemaker_session (Optional[Session]): Custom SageMaker Session to use.
190188
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
191189
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -202,6 +200,9 @@ def validate_hyperparameters(
202200
203201
"""
204202

203+
region = region or get_region_fallback(
204+
sagemaker_session=sagemaker_session,
205+
)
205206
if validation_mode is None:
206207
validation_mode = HyperparameterValidationMode.VALIDATE_PROVIDED
207208

0 commit comments

Comments
 (0)