@@ -167,20 +167,13 @@ def retrieve(
167
167
)
168
168
else :
169
169
_framework = framework
170
- if (
171
- framework == HUGGING_FACE_FRAMEWORK
172
- or framework in TRAINIUM_ALLOWED_FRAMEWORKS
173
- ):
170
+ if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS :
174
171
inference_tool = _get_inference_tool (inference_tool , instance_type )
175
172
if inference_tool in ["neuron" , "neuronx" ]:
176
173
_framework = f"{ framework } -{ inference_tool } "
177
- final_image_scope = _get_final_image_scope (
178
- framework , instance_type , image_scope
179
- )
174
+ final_image_scope = _get_final_image_scope (framework , instance_type , image_scope )
180
175
_validate_for_suppported_frameworks_and_instance_type (framework , instance_type )
181
- config = _config_for_framework_and_scope (
182
- _framework , final_image_scope , accelerator_type
183
- )
176
+ config = _config_for_framework_and_scope (_framework , final_image_scope , accelerator_type )
184
177
185
178
original_version = version
186
179
version = _validate_version_and_set_if_needed (version , config , framework )
@@ -191,14 +184,10 @@ def retrieve(
191
184
full_base_framework_version = version_config ["version_aliases" ].get (
192
185
base_framework_version , base_framework_version
193
186
)
194
- _validate_arg (
195
- full_base_framework_version , list (version_config .keys ()), "base framework"
196
- )
187
+ _validate_arg (full_base_framework_version , list (version_config .keys ()), "base framework" )
197
188
version_config = version_config .get (full_base_framework_version )
198
189
199
- py_version = _validate_py_version_and_set_if_needed (
200
- py_version , version_config , framework
201
- )
190
+ py_version = _validate_py_version_and_set_if_needed (py_version , version_config , framework )
202
191
version_config = version_config .get (py_version ) or version_config
203
192
registry = _registry_from_region (region , version_config ["registries" ])
204
193
endpoint_data = utils ._botocore_resolver ().construct_endpoint ("ecr" , region )
@@ -226,9 +215,7 @@ def retrieve(
226
215
227
216
if framework == HUGGING_FACE_FRAMEWORK :
228
217
pt_or_tf_version = (
229
- re .compile ("^(pytorch|tensorflow)(.*)$" )
230
- .match (base_framework_version )
231
- .group (2 )
218
+ re .compile ("^(pytorch|tensorflow)(.*)$" ).match (base_framework_version ).group (2 )
232
219
)
233
220
_version = original_version
234
221
@@ -252,13 +239,11 @@ def retrieve(
252
239
.get ("version_aliases" , {})
253
240
.get (base_framework_version , {})
254
241
):
255
- _base_framework_version = config .get ("versions" )[_version ][
256
- "version_aliases"
257
- ][ base_framework_version ]
242
+ _base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
243
+ base_framework_version
244
+ ]
258
245
pt_or_tf_version = (
259
- re .compile ("^(pytorch|tensorflow)(.*)$" )
260
- .match (_base_framework_version )
261
- .group (2 )
246
+ re .compile ("^(pytorch|tensorflow)(.*)$" ).match (_base_framework_version ).group (2 )
262
247
)
263
248
264
249
tag_prefix = f"{ pt_or_tf_version } -transformers{ _version } "
@@ -285,9 +270,7 @@ def retrieve(
285
270
if tag :
286
271
repo += ":{}" .format (tag )
287
272
288
- return ECR_URI_TEMPLATE .format (
289
- registry = registry , hostname = hostname , repository = repo
290
- )
273
+ return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
291
274
292
275
293
276
def _get_image_tag (
@@ -326,13 +309,9 @@ def _get_image_tag(
326
309
}
327
310
tag = version_to_arm64_tag_mapping [framework ][version ]
328
311
else :
329
- tag = _format_tag (
330
- tag_prefix , processor , py_version , container_version , inference_tool
331
- )
312
+ tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
332
313
else :
333
- tag = _format_tag (
334
- tag_prefix , processor , py_version , container_version , inference_tool
335
- )
314
+ tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
336
315
337
316
if instance_type is not None and _should_auto_select_container_version (
338
317
instance_type , distribution
@@ -383,11 +362,7 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
383
362
)
384
363
image_scope = available_scopes [0 ]
385
364
386
- if (
387
- not image_scope
388
- and "scope" in config
389
- and set (available_scopes ) == {"training" , "inference" }
390
- ):
365
+ if not image_scope and "scope" in config and set (available_scopes ) == {"training" , "inference" }:
391
366
logger .info (
392
367
"Same images used for training and inference. Defaulting to image scope: %s." ,
393
368
available_scopes [0 ],
@@ -419,27 +394,20 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
419
394
and "trn" in instance_type
420
395
and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
421
396
):
422
- _validate_framework (
423
- framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" , "Trainium"
424
- )
397
+ _validate_framework (framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" , "Trainium" )
425
398
426
399
# Validate for Graviton allowed frameowrks
427
400
if (
428
401
instance_type is not None
429
- and utils .get_instance_type_family (instance_type )
430
- in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
402
+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
431
403
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
432
404
):
433
- _validate_framework (
434
- framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton"
435
- )
405
+ _validate_framework (framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton" )
436
406
437
407
438
408
def config_for_framework (framework ):
439
409
"""Loads the JSON config for the given framework."""
440
- fname = os .path .join (
441
- os .path .dirname (__file__ ), "image_uri_config" , "{}.json" .format (framework )
442
- )
410
+ fname = os .path .join (os .path .dirname (__file__ ), "image_uri_config" , "{}.json" .format (framework ))
443
411
with open (fname ) as f :
444
412
return json .load (f )
445
413
@@ -448,8 +416,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
448
416
"""Return final image scope based on provided framework and instance type."""
449
417
if (
450
418
framework in GRAVITON_ALLOWED_FRAMEWORKS
451
- and utils .get_instance_type_family (instance_type )
452
- in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
419
+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
453
420
):
454
421
return INFERENCE_GRAVITON
455
422
if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
@@ -465,9 +432,7 @@ def _get_inference_tool(inference_tool, instance_type):
465
432
"""Extract the inference tool name from instance type."""
466
433
if not inference_tool :
467
434
instance_type_family = utils .get_instance_type_family (instance_type )
468
- if instance_type_family .startswith ("inf" ) or instance_type_family .startswith (
469
- "trn"
470
- ):
435
+ if instance_type_family .startswith ("inf" ) or instance_type_family .startswith ("trn" ):
471
436
return "neuron"
472
437
return inference_tool
473
438
@@ -479,15 +444,10 @@ def _get_latest_versions(list_of_versions):
479
444
480
445
def _validate_accelerator_type (accelerator_type ):
481
446
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
482
- if (
483
- not accelerator_type .startswith ("ml.eia" )
484
- and accelerator_type != "local_sagemaker_notebook"
485
- ):
447
+ if not accelerator_type .startswith ("ml.eia" ) and accelerator_type != "local_sagemaker_notebook" :
486
448
raise ValueError (
487
449
"Invalid SageMaker Elastic Inference accelerator type: {}. "
488
- "See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html" .format (
489
- accelerator_type
490
- )
450
+ "See https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html" .format (accelerator_type )
491
451
)
492
452
493
453
@@ -497,15 +457,11 @@ def _validate_version_and_set_if_needed(version, config, framework):
497
457
aliased_versions = list (config .get ("version_aliases" , {}).keys ())
498
458
499
459
if len (available_versions ) == 1 and version not in aliased_versions :
500
- log_message = (
501
- "Defaulting to the only supported framework/algorithm version: {}." .format (
502
- available_versions [0 ]
503
- )
460
+ log_message = "Defaulting to the only supported framework/algorithm version: {}." .format (
461
+ available_versions [0 ]
504
462
)
505
463
if version and version != available_versions [0 ]:
506
- logger .warning (
507
- "%s Ignoring framework/algorithm version: %s." , log_message , version
508
- )
464
+ logger .warning ("%s Ignoring framework/algorithm version: %s." , log_message , version )
509
465
elif not version :
510
466
logger .info (log_message )
511
467
@@ -518,9 +474,7 @@ def _validate_version_and_set_if_needed(version, config, framework):
518
474
]:
519
475
version = _get_latest_versions (available_versions )
520
476
521
- _validate_arg (
522
- version , available_versions + aliased_versions , "{} version" .format (framework )
523
- )
477
+ _validate_arg (version , available_versions + aliased_versions , "{} version" .format (framework ))
524
478
return version
525
479
526
480
@@ -546,9 +500,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
546
500
return None
547
501
548
502
if len (available_processors ) == 1 and not instance_type :
549
- logger .info (
550
- "Defaulting to only supported image scope: %s." , available_processors [0 ]
551
- )
503
+ logger .info ("Defaulting to only supported image scope: %s." , available_processors [0 ])
552
504
return available_processors [0 ]
553
505
554
506
if serverless_inference_config is not None :
@@ -585,9 +537,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
585
537
else :
586
538
raise ValueError (
587
539
"Invalid SageMaker instance type: {}. For options, see: "
588
- "https://aws.amazon.com/sagemaker/pricing/instance-types" .format (
589
- instance_type
590
- )
540
+ "https://aws.amazon.com/sagemaker/pricing/instance-types" .format (instance_type )
591
541
)
592
542
593
543
_validate_arg (processor , available_processors , "processor" )
@@ -626,9 +576,7 @@ def _validate_py_version_and_set_if_needed(py_version, version_config, framework
626
576
return None
627
577
628
578
if py_version is None and len (available_versions ) == 1 :
629
- logger .info (
630
- "Defaulting to only available Python version: %s" , available_versions [0 ]
631
- )
579
+ logger .info ("Defaulting to only available Python version: %s" , available_versions [0 ])
632
580
return available_versions [0 ]
633
581
634
582
_validate_arg (py_version , available_versions , "Python version" )
@@ -641,9 +589,7 @@ def _validate_arg(arg, available_options, arg_name):
641
589
raise ValueError (
642
590
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
643
591
"(pip install -U sagemaker) for newer {arg_name}s. Supported {arg_name}(s): "
644
- "{options}." .format (
645
- arg_name = arg_name , arg = arg , options = ", " .join (available_options )
646
- )
592
+ "{options}." .format (arg_name = arg_name , arg = arg , options = ", " .join (available_options ))
647
593
)
648
594
649
595
@@ -656,17 +602,11 @@ def _validate_framework(framework, allowed_frameworks, arg_name, hardware_name):
656
602
)
657
603
658
604
659
- def _format_tag (
660
- tag_prefix , processor , py_version , container_version , inference_tool = None
661
- ):
605
+ def _format_tag (tag_prefix , processor , py_version , container_version , inference_tool = None ):
662
606
"""Creates a tag for the image URI."""
663
607
if inference_tool :
664
- return "-" .join (
665
- x for x in (tag_prefix , inference_tool , py_version , container_version ) if x
666
- )
667
- return "-" .join (
668
- x for x in (tag_prefix , processor , py_version , container_version ) if x
669
- )
608
+ return "-" .join (x for x in (tag_prefix , inference_tool , py_version , container_version ) if x )
609
+ return "-" .join (x for x in (tag_prefix , processor , py_version , container_version ) if x )
670
610
671
611
672
612
@override_pipeline_parameter_var
@@ -775,6 +715,4 @@ def get_base_python_image_uri(region, py_version="310") -> str:
775
715
repo = version_config ["repository" ] + "-" + py_version
776
716
repo_and_tag = repo + ":" + version
777
717
778
- return ECR_URI_TEMPLATE .format (
779
- registry = registry , hostname = hostname , repository = repo_and_tag
780
- )
718
+ return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo_and_tag )
0 commit comments