23
23
from sagemaker .jumpstart import accessors
24
24
from sagemaker .jumpstart .constants import (
25
25
DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
26
- JUMPSTART_DEFAULT_REGION_NAME ,
27
26
)
28
27
from sagemaker .jumpstart .enums import JumpStartScriptScope
29
28
from sagemaker .jumpstart .filters import (
36
35
from sagemaker .jumpstart .types import JumpStartModelHeader , JumpStartModelSpecs
37
36
from sagemaker .jumpstart .utils import (
38
37
get_jumpstart_content_bucket ,
38
+ get_region_fallback ,
39
39
get_sagemaker_version ,
40
40
verify_model_region_and_return_specs ,
41
41
)
@@ -143,7 +143,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
143
143
144
144
def list_jumpstart_tasks ( # pylint: disable=redefined-builtin
145
145
filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
146
- region : str = JUMPSTART_DEFAULT_REGION_NAME ,
146
+ region : Optional [ str ] = None ,
147
147
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
148
148
) -> List [str ]:
149
149
"""List tasks for JumpStart, and optionally apply filters to result.
@@ -155,11 +155,14 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
155
155
(e.g. ``"task == ic"``). If this argument is not supplied, all tasks will be listed.
156
156
(Default: Constant(BooleanValues.TRUE)).
157
157
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
158
- models. (Default: JUMPSTART_DEFAULT_REGION_NAME ).
158
+ models. (Default: None ).
159
159
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
160
160
use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
161
161
"""
162
162
163
+ region = region or get_region_fallback (
164
+ sagemaker_session = sagemaker_session ,
165
+ )
163
166
tasks : Set [str ] = set ()
164
167
for model_id , _ in _generate_jumpstart_model_versions (
165
168
filter = filter , region = region , sagemaker_session = sagemaker_session
@@ -171,7 +174,7 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
171
174
172
175
def list_jumpstart_frameworks ( # pylint: disable=redefined-builtin
173
176
filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
174
- region : str = JUMPSTART_DEFAULT_REGION_NAME ,
177
+ region : Optional [ str ] = None ,
175
178
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
176
179
) -> List [str ]:
177
180
"""List frameworks for JumpStart, and optionally apply filters to result.
@@ -183,11 +186,14 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
183
186
(eg. ``"task == ic"``). If this argument is not supplied, all frameworks will be listed.
184
187
(Default: Constant(BooleanValues.TRUE)).
185
188
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
186
- models. (Default: JUMPSTART_DEFAULT_REGION_NAME ).
189
+ models. (Default: None ).
187
190
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
188
191
to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
189
192
"""
190
193
194
+ region = region or get_region_fallback (
195
+ sagemaker_session = sagemaker_session ,
196
+ )
191
197
frameworks : Set [str ] = set ()
192
198
for model_id , _ in _generate_jumpstart_model_versions (
193
199
filter = filter , region = region , sagemaker_session = sagemaker_session
@@ -199,7 +205,7 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
199
205
200
206
def list_jumpstart_scripts ( # pylint: disable=redefined-builtin
201
207
filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
202
- region : str = JUMPSTART_DEFAULT_REGION_NAME ,
208
+ region : Optional [ str ] = None ,
203
209
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
204
210
) -> List [str ]:
205
211
"""List scripts for JumpStart, and optionally apply filters to result.
@@ -211,10 +217,13 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
211
217
(e.g. ``"task == ic"``). If this argument is not supplied, all scripts will be listed.
212
218
(Default: Constant(BooleanValues.TRUE)).
213
219
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
214
- models. (Default: JUMPSTART_DEFAULT_REGION_NAME ).
220
+ models. (Default: None ).
215
221
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
216
222
use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
217
223
"""
224
+ region = region or get_region_fallback (
225
+ sagemaker_session = sagemaker_session ,
226
+ )
218
227
if (isinstance (filter , Constant ) and filter .resolved_value == BooleanValues .TRUE ) or (
219
228
isinstance (filter , str ) and filter .lower () == BooleanValues .TRUE .lower ()
220
229
):
@@ -242,7 +251,7 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
242
251
243
252
def list_jumpstart_models ( # pylint: disable=redefined-builtin
244
253
filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
245
- region : str = JUMPSTART_DEFAULT_REGION_NAME ,
254
+ region : Optional [ str ] = None ,
246
255
list_incomplete_models : bool = False ,
247
256
list_old_models : bool = False ,
248
257
list_versions : bool = False ,
@@ -257,7 +266,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
257
266
(e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed.
258
267
(Default: Constant(BooleanValues.TRUE)).
259
268
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
260
- models. (Default: JUMPSTART_DEFAULT_REGION_NAME ).
269
+ models. (Default: None ).
261
270
list_incomplete_models (bool): Optional. If a model does not contain metadata fields
262
271
requested by the filter, and the filter cannot be resolved to a include/not include,
263
272
whether the model should be included. By default, these models are omitted from results.
@@ -270,6 +279,9 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
270
279
to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
271
280
"""
272
281
282
+ region = region or get_region_fallback (
283
+ sagemaker_session = sagemaker_session ,
284
+ )
273
285
model_id_version_dict : Dict [str , List [str ]] = dict ()
274
286
for model_id , version in _generate_jumpstart_model_versions (
275
287
filter = filter ,
@@ -299,7 +311,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
299
311
300
312
def _generate_jumpstart_model_versions ( # pylint: disable=redefined-builtin
301
313
filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
302
- region : str = JUMPSTART_DEFAULT_REGION_NAME ,
314
+ region : Optional [ str ] = None ,
303
315
list_incomplete_models : bool = False ,
304
316
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
305
317
) -> Generator :
@@ -312,7 +324,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
312
324
(e.g. ``"task == ic"``). If this argument is not supplied, all models will be generated.
313
325
(Default: Constant(BooleanValues.TRUE)).
314
326
region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
315
- models. (Default: JUMPSTART_DEFAULT_REGION_NAME ).
327
+ models. (Default: None ).
316
328
list_incomplete_models (bool): Optional. If a model does not contain metadata fields
317
329
requested by the filter, and the filter cannot be resolved to a include/not include,
318
330
whether the model should be included. By default, these models are omitted from
@@ -321,6 +333,10 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
321
333
to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
322
334
"""
323
335
336
+ region = region or get_region_fallback (
337
+ sagemaker_session = sagemaker_session ,
338
+ )
339
+
324
340
models_manifest_list = accessors .JumpStartModelsAccessor ._get_manifest (
325
341
region = region , s3_client = sagemaker_session .s3_client
326
342
)
@@ -453,7 +469,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
453
469
def get_model_url (
454
470
model_id : str ,
455
471
model_version : str ,
456
- region : str = JUMPSTART_DEFAULT_REGION_NAME ,
472
+ region : Optional [ str ] = None ,
457
473
sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
458
474
) -> str :
459
475
"""Retrieve web url describing pretrained model.
@@ -462,11 +478,14 @@ def get_model_url(
462
478
model_id (str): The model ID for which to retrieve the url.
463
479
model_version (str): The model version for which to retrieve the url.
464
480
region (str): Optional. The region from which to retrieve metadata.
465
- (Default: JUMPSTART_DEFAULT_REGION_NAME )
481
+ (Default: None )
466
482
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
467
483
to retrieve the model url.
468
484
"""
469
485
486
+ region = region or get_region_fallback (
487
+ sagemaker_session = sagemaker_session ,
488
+ )
470
489
model_specs = verify_model_region_and_return_specs (
471
490
region = region ,
472
491
model_id = model_id ,
0 commit comments