diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index c05511ef70..9834d4eb65 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -54,6 +54,56 @@ VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"] VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"} +MERGED_FRAMEWORKS_REPO_MAP = { + "tensorflow-scriptmode": "tensorflow-training", + "mxnet": "mxnet-training", + "tensorflow-serving": "tensorflow-inference", + "mxnet-serving": "mxnet-inference", +} + +MERGED_FRAMEWORKS_LOWEST_VERSIONS = { + "tensorflow-scriptmode": [1, 13, 1], + "mxnet": [1, 4, 1], + "tensorflow-serving": [1, 13, 0], + "mxnet-serving": [1, 4, 1], +} + + +def is_version_equal_or_higher(lowest_version, framework_version): + """Determine whether the ``framework_version`` is equal to or higher than ``lowest_version`` + + Args: + lowest_version (List[int]): lowest version represented in an integer list + framework_version (str): framework version string + + Returns: + bool: Whether or not framework_version is equal to or higher than lowest_version + """ + version_list = [int(s) for s in framework_version.split(".")] + return version_list >= lowest_version[0 : len(version_list)] + + +def _is_merged_versions(framework, framework_version): + lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework) + if lowest_version_list: + return is_version_equal_or_higher(lowest_version_list, framework_version) + else: + return False + + +def _using_merged_images(region, framework, py_version, accelerator_type, framework_version): + is_gov_region = region in VALID_ACCOUNTS_BY_REGION + is_py3 = py_version == "py3" or py_version is None + is_merged_versions = _is_merged_versions(framework, framework_version) + return (not is_gov_region) and is_merged_versions and is_py3 and accelerator_type is None + + +def _registry_id(region, framework, py_version, account, accelerator_type, framework_version): + if _using_merged_images(region, framework, py_version, accelerator_type, framework_version): + return "763104351884" + else: + return VALID_ACCOUNTS_BY_REGION.get(region, account) + def create_image_uri( region, @@ -86,8 +136,15 @@ def create_image_uri( if py_version and py_version not in VALID_PY_VERSIONS: raise ValueError("invalid py_version argument: {}".format(py_version)) - # Handle Account Number for Gov Cloud - account = VALID_ACCOUNTS_BY_REGION.get(region, account) + # Handle Account Number for Gov Cloud and frameworks with DLC merged images + account = _registry_id( + region=region, + framework=framework, + py_version=py_version, + account=account, + accelerator_type=accelerator_type, + framework_version=framework_version, + ) # Handle Local Mode if instance_type.startswith("local"): @@ -121,7 +178,14 @@ def create_image_uri( ): framework += "-eia" - return "{}/sagemaker-{}:{}".format(get_ecr_image_uri_prefix(account, region), framework, tag) + if _using_merged_images(region, framework, py_version, accelerator_type, framework_version): + return "{}/{}:{}".format( + get_ecr_image_uri_prefix(account, region), MERGED_FRAMEWORKS_REPO_MAP[framework], tag + ) + else: + return "{}/sagemaker-{}:{}".format( + get_ecr_image_uri_prefix(account, region), framework, tag + ) def _accelerator_type_valid_for_framework( @@ -264,7 +328,7 @@ def framework_name_from_image(image_name): # extract framework, python version and image tag # We must support both the legacy and current image name format. name_pattern = re.compile( - r"^sagemaker(?:-rl)?-(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 + r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 ) legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$") diff --git a/tests/integ/test_tf_script_mode.py b/tests/integ/test_tf_script_mode.py index 49255fd9a3..9c75575bc2 100644 --- a/tests/integ/test_tf_script_mode.py +++ b/tests/integ/test_tf_script_mode.py @@ -65,6 +65,7 @@ def test_mnist(sagemaker_session, instance_type): sagemaker_session=sagemaker_session, script_mode=True, framework_version=TensorFlow.LATEST_VERSION, + py_version=tests.integ.PYTHON_VERSION, metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}], ) inputs = estimator.sagemaker_session.upload_data( @@ -98,6 +99,7 @@ def test_server_side_encryption(sagemaker_session): sagemaker_session=sagemaker_session, script_mode=True, framework_version=TensorFlow.LATEST_VERSION, + py_version=tests.integ.PYTHON_VERSION, code_location=output_path, output_path=output_path, model_dir="/opt/ml/model", @@ -144,6 +146,7 @@ def test_mnist_async(sagemaker_session): role=ROLE, train_instance_count=1, train_instance_type="ml.c5.4xlarge", + py_version=tests.integ.PYTHON_VERSION, sagemaker_session=sagemaker_session, script_mode=True, framework_version=TensorFlow.LATEST_VERSION, @@ -182,6 +185,7 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type): role=ROLE, train_instance_count=1, train_instance_type=instance_type, + py_version=tests.integ.PYTHON_VERSION, sagemaker_session=sagemaker_session, script_mode=True, framework_version=TensorFlow.LATEST_VERSION, diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 05aca78e38..5d1a64d231 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -136,18 +136,59 @@ def test_create_image_uri_gov_cloud(): ) +def test_create_image_uri_merged(): + image_uri = fw_utils.create_image_uri( + "us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py3" + ) + assert ( + image_uri + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.13.1-gpu-py3" + ) + + image_uri = fw_utils.create_image_uri( + "us-west-2", "tensorflow-serving", "ml.c4.2xlarge", "1.13.1" + ) + assert ( + image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.13.1-cpu" + ) + + image_uri = fw_utils.create_image_uri("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py3") + assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-training:1.4.1-gpu-py3" + + image_uri = fw_utils.create_image_uri( + "us-west-2", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py3" + ) + assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.4.1-cpu-py3" + + +def test_create_image_uri_merged_py2(): + image_uri = fw_utils.create_image_uri( + "us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py2" + ) + assert ( + image_uri + == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.13.1-gpu-py2" + ) + + image_uri = fw_utils.create_image_uri("us-west-2", "mxnet", "ml.p3.2xlarge", "1.4.1", "py2") + assert image_uri == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.1-gpu-py2" + + image_uri = fw_utils.create_image_uri( + "us-west-2", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py2" + ) + assert ( + image_uri + == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-serving:1.4.1-cpu-py2" + ) + + def test_create_image_uri_accelerator_tf(): image_uri = fw_utils.create_image_uri( - MOCK_REGION, - "tensorflow", - "ml.p3.2xlarge", - "1.0rc", - "py3", - accelerator_type="ml.eia1.medium", + MOCK_REGION, "tensorflow", "ml.p3.2xlarge", "1.0", "py3", accelerator_type="ml.eia1.medium" ) assert ( image_uri - == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-eia:1.0rc-gpu-py3" + == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-tensorflow-eia:1.0-gpu-py3" ) @@ -156,13 +197,13 @@ def test_create_image_uri_accelerator_mxnet_serving(): MOCK_REGION, "mxnet-serving", "ml.p3.2xlarge", - "1.0rc", + "1.0", "py3", accelerator_type="ml.eia1.medium", ) assert ( image_uri - == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-serving-eia:1.0rc-gpu-py3" + == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-serving-eia:1.0-gpu-py3" ) @@ -171,13 +212,13 @@ def test_create_image_uri_local_sagemaker_notebook_accelerator(): MOCK_REGION, "mxnet", "ml.p3.2xlarge", - "1.0rc", + "1.0", "py3", accelerator_type="local_sagemaker_notebook", ) assert ( image_uri - == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0rc-gpu-py3" + == "520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mxnet-eia:1.0-gpu-py3" ) @@ -555,6 +596,11 @@ def test_framework_name_from_image_tf_scriptmode(): "scriptmode", ) == fw_utils.framework_name_from_image(image_name) + image_name = "123.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:1.13-cpu-py3" + assert ("tensorflow", "py3", "1.13-cpu-py3", "training") == fw_utils.framework_name_from_image( + image_name + ) + def test_framework_name_from_image_rl(): image_name = "123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-mxnet:toolkit1.1-gpu-py3" diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index 9b4c5333df..f07612eb5f 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -924,7 +924,7 @@ def test_script_mode_tensorboard( sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, - framework_version="some_version", + framework_version="1.0", script_mode=True, ) popen().poll.return_value = None