17
17
import logging
18
18
import os
19
19
import re
20
- from typing import Optional
20
+ from typing import Optional , Tuple
21
21
from packaging .version import Version
22
22
23
23
from sagemaker import utils
52
52
53
53
@override_pipeline_parameter_var
54
54
def retrieve (
55
- framework ,
56
- region ,
57
- version = None ,
58
- py_version = None ,
59
- instance_type = None ,
60
- accelerator_type = None ,
61
- image_scope = None ,
62
- container_version = None ,
63
- distribution = None ,
64
- base_framework_version = None ,
65
- training_compiler_config = None ,
66
- model_id = None ,
67
- model_version = None ,
68
- hub_arn = None ,
69
- tolerate_vulnerable_model = False ,
70
- tolerate_deprecated_model = False ,
71
- sdk_version = None ,
72
- inference_tool = None ,
73
- serverless_inference_config = None ,
74
- sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
75
- config_name = None ,
76
- model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
55
+ framework ,
56
+ region ,
57
+ version = None ,
58
+ py_version = None ,
59
+ instance_type = None ,
60
+ accelerator_type = None ,
61
+ image_scope = None ,
62
+ container_version = None ,
63
+ distribution = None ,
64
+ base_framework_version = None ,
65
+ training_compiler_config = None ,
66
+ model_id = None ,
67
+ model_version = None ,
68
+ hub_arn = None ,
69
+ tolerate_vulnerable_model = False ,
70
+ tolerate_deprecated_model = False ,
71
+ sdk_version = None ,
72
+ inference_tool = None ,
73
+ serverless_inference_config = None ,
74
+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
75
+ config_name = None ,
76
+ model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
77
77
) -> str :
78
78
"""Retrieves the ECR URI for the Docker image matching the given arguments.
79
79
@@ -250,10 +250,10 @@ def retrieve(
250
250
if config .get ("version_aliases" ).get (original_version ):
251
251
_version = config .get ("version_aliases" )[original_version ]
252
252
if (
253
- config .get ("versions" , {})
254
- .get (_version , {})
255
- .get ("version_aliases" , {})
256
- .get (base_framework_version , {})
253
+ config .get ("versions" , {})
254
+ .get (_version , {})
255
+ .get ("version_aliases" , {})
256
+ .get (base_framework_version , {})
257
257
):
258
258
_base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
259
259
base_framework_version
@@ -290,16 +290,16 @@ def retrieve(
290
290
291
291
292
292
def _get_image_tag (
293
- container_version ,
294
- distribution ,
295
- final_image_scope ,
296
- framework ,
297
- inference_tool ,
298
- instance_type ,
299
- processor ,
300
- py_version ,
301
- tag_prefix ,
302
- version ,
293
+ container_version ,
294
+ distribution ,
295
+ final_image_scope ,
296
+ framework ,
297
+ inference_tool ,
298
+ instance_type ,
299
+ processor ,
300
+ py_version ,
301
+ tag_prefix ,
302
+ version ,
303
303
):
304
304
"""Return image tag based on framework, container, and compute configuration(s)."""
305
305
instance_type_family = utils .get_instance_type_family (instance_type )
@@ -311,8 +311,8 @@ def _get_image_tag(
311
311
"instance type" ,
312
312
)
313
313
if (
314
- instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
315
- or final_image_scope == INFERENCE_GRAVITON
314
+ instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
315
+ or final_image_scope == INFERENCE_GRAVITON
316
316
):
317
317
version_to_arm64_tag_mapping = {
318
318
"xgboost" : {
@@ -330,7 +330,7 @@ def _get_image_tag(
330
330
tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
331
331
332
332
if instance_type is not None and _should_auto_select_container_version (
333
- instance_type , distribution
333
+ instance_type , distribution
334
334
):
335
335
container_versions = {
336
336
"tensorflow-2.3-gpu-py37" : "cu110-ubuntu18.04-v3" ,
@@ -398,7 +398,7 @@ def _validate_instance_deprecation(framework, instance_type, version):
398
398
"""Check if instance type is deprecated for a certain framework with a certain version"""
399
399
if utils .get_instance_type_family (instance_type ) == "p2" :
400
400
if (framework == "pytorch" and Version (version ) >= Version ("1.13" )) or (
401
- framework == "tensorflow" and Version (version ) >= Version ("2.12" )
401
+ framework == "tensorflow" and Version (version ) >= Version ("2.12" )
402
402
):
403
403
raise ValueError (
404
404
"P2 instances have been deprecated for sagemaker jobs starting PyTorch 1.13 and TensorFlow 2.12"
@@ -411,17 +411,17 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
411
411
"""Validate if framework is supported for the instance_type"""
412
412
# Validate for Trainium allowed frameworks
413
413
if (
414
- instance_type is not None
415
- and "trn" in instance_type
416
- and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
414
+ instance_type is not None
415
+ and "trn" in instance_type
416
+ and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
417
417
):
418
418
_validate_framework (framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" , "Trainium" )
419
419
420
420
# Validate for Graviton allowed frameworks
421
421
if (
422
- instance_type is not None
423
- and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
424
- and framework not in GRAVITON_ALLOWED_FRAMEWORKS
422
+ instance_type is not None
423
+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
424
+ and framework not in GRAVITON_ALLOWED_FRAMEWORKS
425
425
):
426
426
_validate_framework (framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton" )
427
427
@@ -436,8 +436,8 @@ def config_for_framework(framework):
436
436
def _get_final_image_scope (framework , instance_type , image_scope ):
437
437
"""Return final image scope based on provided framework and instance type."""
438
438
if (
439
- framework in GRAVITON_ALLOWED_FRAMEWORKS
440
- and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
439
+ framework in GRAVITON_ALLOWED_FRAMEWORKS
440
+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
441
441
):
442
442
return INFERENCE_GRAVITON
443
443
if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
@@ -635,16 +635,16 @@ def _format_tag(tag_prefix, processor, py_version, container_version, inference_
635
635
636
636
@override_pipeline_parameter_var
637
637
def get_training_image_uri (
638
- region ,
639
- framework ,
640
- framework_version = None ,
641
- py_version = None ,
642
- image_uri = None ,
643
- distribution = None ,
644
- compiler_config = None ,
645
- tensorflow_version = None ,
646
- pytorch_version = None ,
647
- instance_type = None ,
638
+ region ,
639
+ framework ,
640
+ framework_version = None ,
641
+ py_version = None ,
642
+ image_uri = None ,
643
+ distribution = None ,
644
+ compiler_config = None ,
645
+ tensorflow_version = None ,
646
+ pytorch_version = None ,
647
+ instance_type = None ,
648
648
) -> str :
649
649
"""Retrieves the image URI for training.
650
650
@@ -746,3 +746,141 @@ def get_base_python_image_uri(region, py_version="310") -> str:
746
746
repo_and_tag = repo + ":" + version
747
747
748
748
return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo_and_tag )
749
+
750
+
751
def get_latest_container_image(
    framework: str,
    image_scope: Optional[str] = None,
    instance_type: Optional[str] = None,
    py_version: Optional[str] = None,
    region: str = "us-west-2",
    version: Optional[str] = None,
    accelerator_type=None,
    container_version=None,
    distribution=None,
    base_framework_version=None,
    training_compiler_config=None,
    model_id=None,
    model_version=None,
    hub_arn=None,
    sdk_version=None,
    inference_tool=None,
    serverless_inference_config=None,
    config_name=None,
) -> Tuple[str, Optional[str]]:
    """Retrieves the latest container image URI together with the version used.

    When ``version`` is not supplied, the latest version is looked up in the framework's
    config via ``_fetch_latest_version_from_config`` and then passed to ``retrieve``.

    Args:
        framework (str): The name of the framework or algorithm.
        image_scope (str): The image type, i.e. what it is used for.
            Valid values: "training", "inference", "inference_graviton", "eia".
            If ``accelerator_type`` is set, ``image_scope`` is ignored.
        instance_type (str): The SageMaker instance type. For supported types, see
            https://aws.amazon.com/sagemaker/pricing. This is required if
            there are different images for different processor types.
        py_version (str): The Python version. This is required if there is
            more than one supported Python version for the given framework version.
        region (str): The AWS region (default: "us-west-2").
        version (str): The framework or algorithm version. If empty or None, the
            latest version found in the framework config is used instead.
        accelerator_type (str): Elastic Inference accelerator type. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
        container_version (str): the version of docker image.
            Ideally the value of parameter should be created inside the framework.
            For custom use, see the list of supported container versions:
            https://github.com/aws/deep-learning-containers/blob/master/available_images.md
            (default: None).
        distribution (dict): A dictionary with information on how to run distributed training
        base_framework_version: Forwarded unchanged to ``retrieve`` (default: None).
        training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).
        model_id (str): The JumpStart model ID for which to retrieve the image URI
            (default: None).
        model_version (str): The version of the JumpStart model for which to retrieve the
            image URI (default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        sdk_version (str): the version of python-sdk that will be used in the image retrieval.
            (default: None).
        inference_tool (str): the tool that will be used to aid in the inference.
            Valid values: "neuron, neuronx, None"
            (default: None).
        serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
            Specifies configuration related to serverless endpoint. Instance type is
            not provided in serverless inference. So this is used to determine processor type.
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Returns:
        Tuple[str, Optional[str]]: The image URI returned by ``retrieve`` and the version
        that was passed to it. The version is ``None`` when no ``version`` was given and
        no latest version could be determined from the framework config.

    Raises:
        ValueError: If loading the framework config raises ``FileNotFoundError`` or the
            loaded config is empty.
    """
    try:
        framework_config = config_for_framework(framework)
    except FileNotFoundError as err:
        # Chain the original error so the underlying lookup failure stays in the traceback.
        raise ValueError("Invalid framework {}".format(framework)) from err

    if not framework_config:
        raise ValueError("Invalid framework {}".format(framework))

    if not version:
        version = _fetch_latest_version_from_config(framework_config, image_scope)

    image_uri = retrieve(
        framework=framework,
        region=region,
        version=version,
        instance_type=instance_type,
        py_version=py_version,
        accelerator_type=accelerator_type,
        image_scope=image_scope,
        container_version=container_version,
        distribution=distribution,
        base_framework_version=base_framework_version,
        training_compiler_config=training_compiler_config,
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        sdk_version=sdk_version,
        inference_tool=inference_tool,
        serverless_inference_config=serverless_inference_config,
        config_name=config_name,
    )
    return image_uri, version
841
+
842
+
843
+ def _fetch_latest_version_from_config (framework_config : dict ,
844
+ image_scope : Optional [str ] = None ) -> Optional [str ]:
845
+ """ Helper function to fetch the latest version as a string from a framework's config
846
+ Args:
847
+ framework_config (dict): A framework config dict.
848
+ image_scope (str): Scope of the image, eg: training, inference
849
+ Returns:
850
+ Version string if latest version found else None
851
+ """
852
+ if image_scope in framework_config :
853
+ if image_scope_config := framework_config [image_scope ]:
854
+ if "version_aliases" in image_scope_config :
855
+ if "latest" in image_scope_config ["version_aliases" ]:
856
+ return image_scope_config ["version_aliases" ]["latest" ]
857
+ top_version = None
858
+ bottom_version = None
859
+
860
+ if "versions" in framework_config :
861
+ versions = list (framework_config ["versions" ].keys ())
862
+ top_version = versions [0 ]
863
+ bottom_version = versions [- 1 ]
864
+ if top_version == "latest" or bottom_version == "latest" :
865
+ return None
866
+ elif (image_scope is not None and image_scope in framework_config
867
+ and "versions" in framework_config [image_scope ]):
868
+ versions = list (framework_config [image_scope ]["versions" ].keys ())
869
+ top_version = versions [0 ]
870
+ bottom_version = versions [- 1 ]
871
+ elif "processing" in framework_config and "versions" in framework_config ["processing" ]:
872
+ versions = list (framework_config ["processing" ]["versions" ].keys ())
873
+ top_version = versions [0 ]
874
+ bottom_version = versions [- 1 ]
875
+
876
+ if top_version and bottom_version :
877
+ if top_version .endswith (".x" ) or bottom_version .endswith (".x" ):
878
+ top_number = int (top_version [:- 2 ])
879
+ bottom_number = int (bottom_version [:- 2 ])
880
+ max_version = max (top_number , bottom_number )
881
+ return f"{ max_version } .x"
882
+ if Version (top_version ) >= Version (bottom_version ):
883
+ return top_version
884
+ return bottom_version
885
+
886
+ return None
0 commit comments