Skip to content

Commit 1696df3

Browse files
nargokul authored and pintaoz-aws committed
Latest Container Image (#1545)
* Latest Container Image * Test Fixes * Parameterized tests and some logic updates * Test fixes * Move to Image URI * Fixes for unit test * Fixes for unit test * Fix codestyle error checks
1 parent 712a23a commit 1696df3

File tree

5 files changed

+373
-59
lines changed

5 files changed

+373
-59
lines changed

src/sagemaker/image_uris.py

+197-59
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import logging
1818
import os
1919
import re
20-
from typing import Optional
20+
from typing import Optional, Tuple
2121
from packaging.version import Version
2222

2323
from sagemaker import utils
@@ -52,28 +52,28 @@
5252

5353
@override_pipeline_parameter_var
5454
def retrieve(
55-
framework,
56-
region,
57-
version=None,
58-
py_version=None,
59-
instance_type=None,
60-
accelerator_type=None,
61-
image_scope=None,
62-
container_version=None,
63-
distribution=None,
64-
base_framework_version=None,
65-
training_compiler_config=None,
66-
model_id=None,
67-
model_version=None,
68-
hub_arn=None,
69-
tolerate_vulnerable_model=False,
70-
tolerate_deprecated_model=False,
71-
sdk_version=None,
72-
inference_tool=None,
73-
serverless_inference_config=None,
74-
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
75-
config_name=None,
76-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
55+
framework,
56+
region,
57+
version=None,
58+
py_version=None,
59+
instance_type=None,
60+
accelerator_type=None,
61+
image_scope=None,
62+
container_version=None,
63+
distribution=None,
64+
base_framework_version=None,
65+
training_compiler_config=None,
66+
model_id=None,
67+
model_version=None,
68+
hub_arn=None,
69+
tolerate_vulnerable_model=False,
70+
tolerate_deprecated_model=False,
71+
sdk_version=None,
72+
inference_tool=None,
73+
serverless_inference_config=None,
74+
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
75+
config_name=None,
76+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
7777
) -> str:
7878
"""Retrieves the ECR URI for the Docker image matching the given arguments.
7979
@@ -250,10 +250,10 @@ def retrieve(
250250
if config.get("version_aliases").get(original_version):
251251
_version = config.get("version_aliases")[original_version]
252252
if (
253-
config.get("versions", {})
254-
.get(_version, {})
255-
.get("version_aliases", {})
256-
.get(base_framework_version, {})
253+
config.get("versions", {})
254+
.get(_version, {})
255+
.get("version_aliases", {})
256+
.get(base_framework_version, {})
257257
):
258258
_base_framework_version = config.get("versions")[_version]["version_aliases"][
259259
base_framework_version
@@ -290,16 +290,16 @@ def retrieve(
290290

291291

292292
def _get_image_tag(
293-
container_version,
294-
distribution,
295-
final_image_scope,
296-
framework,
297-
inference_tool,
298-
instance_type,
299-
processor,
300-
py_version,
301-
tag_prefix,
302-
version,
293+
container_version,
294+
distribution,
295+
final_image_scope,
296+
framework,
297+
inference_tool,
298+
instance_type,
299+
processor,
300+
py_version,
301+
tag_prefix,
302+
version,
303303
):
304304
"""Return image tag based on framework, container, and compute configuration(s)."""
305305
instance_type_family = utils.get_instance_type_family(instance_type)
@@ -311,8 +311,8 @@ def _get_image_tag(
311311
"instance type",
312312
)
313313
if (
314-
instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
315-
or final_image_scope == INFERENCE_GRAVITON
314+
instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
315+
or final_image_scope == INFERENCE_GRAVITON
316316
):
317317
version_to_arm64_tag_mapping = {
318318
"xgboost": {
@@ -330,7 +330,7 @@ def _get_image_tag(
330330
tag = _format_tag(tag_prefix, processor, py_version, container_version, inference_tool)
331331

332332
if instance_type is not None and _should_auto_select_container_version(
333-
instance_type, distribution
333+
instance_type, distribution
334334
):
335335
container_versions = {
336336
"tensorflow-2.3-gpu-py37": "cu110-ubuntu18.04-v3",
@@ -398,7 +398,7 @@ def _validate_instance_deprecation(framework, instance_type, version):
398398
"""Check if instance type is deprecated for a certain framework with a certain version"""
399399
if utils.get_instance_type_family(instance_type) == "p2":
400400
if (framework == "pytorch" and Version(version) >= Version("1.13")) or (
401-
framework == "tensorflow" and Version(version) >= Version("2.12")
401+
framework == "tensorflow" and Version(version) >= Version("2.12")
402402
):
403403
raise ValueError(
404404
"P2 instances have been deprecated for sagemaker jobs starting PyTorch 1.13 and TensorFlow 2.12"
@@ -411,17 +411,17 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
411411
"""Validate if framework is supported for the instance_type"""
412412
# Validate for Trainium allowed frameworks
413413
if (
414-
instance_type is not None
415-
and "trn" in instance_type
416-
and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
414+
instance_type is not None
415+
and "trn" in instance_type
416+
and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
417417
):
418418
_validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework", "Trainium")
419419

420420
# Validate for Graviton allowed frameworks
421421
if (
422-
instance_type is not None
423-
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
424-
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
422+
instance_type is not None
423+
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
424+
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
425425
):
426426
_validate_framework(framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton")
427427

@@ -436,8 +436,8 @@ def config_for_framework(framework):
436436
def _get_final_image_scope(framework, instance_type, image_scope):
437437
"""Return final image scope based on provided framework and instance type."""
438438
if (
439-
framework in GRAVITON_ALLOWED_FRAMEWORKS
440-
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
439+
framework in GRAVITON_ALLOWED_FRAMEWORKS
440+
and utils.get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
441441
):
442442
return INFERENCE_GRAVITON
443443
if image_scope is None and framework in (XGBOOST_FRAMEWORK, SKLEARN_FRAMEWORK):
@@ -635,16 +635,16 @@ def _format_tag(tag_prefix, processor, py_version, container_version, inference_
635635

636636
@override_pipeline_parameter_var
637637
def get_training_image_uri(
638-
region,
639-
framework,
640-
framework_version=None,
641-
py_version=None,
642-
image_uri=None,
643-
distribution=None,
644-
compiler_config=None,
645-
tensorflow_version=None,
646-
pytorch_version=None,
647-
instance_type=None,
638+
region,
639+
framework,
640+
framework_version=None,
641+
py_version=None,
642+
image_uri=None,
643+
distribution=None,
644+
compiler_config=None,
645+
tensorflow_version=None,
646+
pytorch_version=None,
647+
instance_type=None,
648648
) -> str:
649649
"""Retrieves the image URI for training.
650650
@@ -746,3 +746,141 @@ def get_base_python_image_uri(region, py_version="310") -> str:
746746
repo_and_tag = repo + ":" + version
747747

748748
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo_and_tag)
749+
750+
751+
def get_latest_container_image(
    framework: str,
    image_scope: Optional[str] = None,
    instance_type: Optional[str] = None,
    py_version: Optional[str] = None,
    region: str = "us-west-2",
    version: Optional[str] = None,
    accelerator_type=None,
    container_version=None,
    distribution=None,
    base_framework_version=None,
    training_compiler_config=None,
    model_id=None,
    model_version=None,
    hub_arn=None,
    sdk_version=None,
    inference_tool=None,
    serverless_inference_config=None,
    config_name=None,
) -> Tuple[str, str]:
    """Retrieves the latest container image URI.

    When ``version`` is not supplied, the version is looked up from the
    framework's config via ``_fetch_latest_version_from_config`` and the
    remaining arguments are forwarded unchanged to ``retrieve``.

    Args:
        framework (str): The name of the framework or algorithm.
        image_scope (str): The image type, i.e. what it is used for.
            Valid values: "training", "inference", "inference_graviton", "eia".
            If ``accelerator_type`` is set, ``image_scope`` is ignored.
        instance_type (str): The SageMaker instance type. For supported types, see
            https://aws.amazon.com/sagemaker/pricing. This is required if
            there are different images for different processor types.
        py_version (str): The Python version. This is required if there is
            more than one supported Python version for the given framework version.
        region (str): The AWS region (default: "us-west-2").
        version (str): The framework or algorithm version. If omitted or empty,
            the latest version found in the framework config is used.
        accelerator_type (str): Elastic Inference accelerator type. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
        container_version (str): the version of docker image.
            Ideally the value of parameter should be created inside the framework.
            For custom use, see the list of supported container versions:
            https://github.com/aws/deep-learning-containers/blob/master/available_images.md
            (default: None).
        distribution (dict): A dictionary with information on how to run distributed training
            (default: None).
        base_framework_version (str): Forwarded as-is to ``retrieve`` (default: None).
        training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).
        model_id (str): The JumpStart model ID for which to retrieve the image URI
            (default: None).
        model_version (str): The version of the JumpStart model for which to retrieve the
            image URI (default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        sdk_version (str): the version of python-sdk that will be used in the image retrieval.
            (default: None).
        inference_tool (str): the tool that will be used to aid in the inference.
            Valid values: "neuron, neuronx, None"
            (default: None).
        serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
            Specifies configuration related to serverless endpoint. Instance type is
            not provided in serverless inference. So this is used to determine processor type.
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Returns:
        Tuple[str, str]: The image URI returned by ``retrieve`` and the version that
        was passed to it (the caller's ``version``, or the one resolved from the config).
        NOTE(review): the resolved version can be ``None`` when the config lookup
        finds nothing; whether ``retrieve`` accepts that is not visible here.

    Raises:
        ValueError: If the framework config file is missing or the loaded config is empty.
    """
    try:
        framework_config = config_for_framework(framework)
    except FileNotFoundError as err:
        # Chain the original error so the missing config path stays visible in tracebacks.
        raise ValueError("Invalid framework {}".format(framework)) from err

    if not framework_config:
        raise ValueError("Invalid framework {}".format(framework))

    if not version:
        version = _fetch_latest_version_from_config(framework_config, image_scope)

    image_uri = retrieve(
        framework=framework,
        region=region,
        version=version,
        instance_type=instance_type,
        py_version=py_version,
        accelerator_type=accelerator_type,
        image_scope=image_scope,
        container_version=container_version,
        distribution=distribution,
        base_framework_version=base_framework_version,
        training_compiler_config=training_compiler_config,
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        sdk_version=sdk_version,
        inference_tool=inference_tool,
        serverless_inference_config=serverless_inference_config,
        config_name=config_name,
    )
    return image_uri, version
841+
842+
843+
def _fetch_latest_version_from_config(
    framework_config: dict, image_scope: Optional[str] = None
) -> Optional[str]:
    """Helper function to fetch the latest version as a string from a framework's config.

    Resolution order:

    1. ``framework_config[image_scope]["version_aliases"]["latest"]``, if present.
    2. Otherwise the first ``"versions"`` mapping found is used, checked in this
       order: top level, ``image_scope`` section, ``"processing"`` section. Only
       its first and last keys are compared and the greater one is returned.

    Args:
        framework_config (dict): A framework config dict.
        image_scope (str): Scope of the image, eg: training, inference

    Returns:
        Version string if latest version found else None. ``None`` is also
        returned when the selected ``"versions"`` mapping is empty, or when the
        top-level ``"versions"`` mapping has ``"latest"`` as its first or last key.
    """
    # ``or {}`` tolerates sections that are missing or explicitly null.
    scope_config = framework_config.get(image_scope) or {}
    aliases = scope_config.get("version_aliases") or {}
    if "latest" in aliases:
        return aliases["latest"]

    processing_config = framework_config.get("processing") or {}

    if "versions" in framework_config:
        versions = list(framework_config["versions"] or ())
        if versions and "latest" in (versions[0], versions[-1]):
            return None
    elif image_scope is not None and "versions" in scope_config:
        versions = list(scope_config["versions"] or ())
    elif "versions" in processing_config:
        versions = list(processing_config["versions"] or ())
    else:
        versions = []

    # An empty mapping has no first/last key to compare.
    if not versions:
        return None

    top_version, bottom_version = versions[0], versions[-1]
    if not (top_version and bottom_version):
        return None

    if top_version.endswith(".x") or bottom_version.endswith(".x"):
        # NOTE(review): both keys are stripped of two trailing characters even if
        # only one ends in ".x"; this presumably expects "N.x"-style keys on both
        # ends -- confirm against the shipped configs.
        top_number = int(top_version[:-2])
        bottom_number = int(bottom_version[:-2])
        return f"{max(top_number, bottom_number)}.x"

    if Version(top_version) >= Version(bottom_version):
        return top_version
    return bottom_version
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"mpi_additional_options": ["-x", "MASTER_ADDR=algo-1", "-x", "MASTER_PORT=7777"], "_type": "mpi"}

0 commit comments

Comments
 (0)