diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index f2ab6fd4e1..e953a731b8 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -1,4 +1,4 @@ -# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -56,22 +56,25 @@ VALID_PY_VERSIONS = ["py2", "py3"] VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"] VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"} +ASIMOV_VALID_ACCOUNTS_BY_REGION = {"us-iso-east-1": "886529160074"} OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "057415533634"} ASIMOV_OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "871362719292"} MERGED_FRAMEWORKS_REPO_MAP = { "tensorflow-scriptmode": "tensorflow-training", - "mxnet": "mxnet-training", "tensorflow-serving": "tensorflow-inference", "tensorflow-serving-eia": "tensorflow-inference-eia", + "mxnet": "mxnet-training", + "mxnet-serving": "mxnet-inference", "mxnet-serving-eia": "mxnet-inference-eia", } MERGED_FRAMEWORKS_LOWEST_VERSIONS = { "tensorflow-scriptmode": [1, 13, 1], - "mxnet": [1, 4, 1], "tensorflow-serving": [1, 13, 0], "tensorflow-serving-eia": [1, 14, 0], + "mxnet": [1, 4, 1], + "mxnet-serving": [1, 4, 1], "mxnet-serving-eia": [1, 4, 1], } @@ -116,13 +119,9 @@ def _using_merged_images(region, framework, py_version, framework_version): is_py3 = py_version == "py3" or py_version is None is_merged_versions = _is_merged_versions(framework, framework_version) return ( - (not is_gov_region) + ((not is_gov_region) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION) and is_merged_versions - and ( - is_py3 - or _is_tf_14_or_later(framework, framework_version) - or _is_mxnet_serving_141_or_later(framework, framework_version) - ) + and (is_py3 or _is_tf_14_or_later(framework, framework_version)) ) @@ -140,24 +139,6 @@ def _is_tf_14_or_later(framework, framework_version): ) -def _is_mxnet_serving_141_or_later(framework, framework_version): - """ - Args: - framework: - framework_version: - """ - asimov_lowest_mxnet = [1, 4, 1] - - version = [int(s) for s in framework_version.split(".")] - - if len(version) == 2: - version.append(0) - - return ( - framework.startswith("mxnet-serving") and version >= asimov_lowest_mxnet[0 : len(version)] - ) - - def _registry_id(region, framework, py_version, account, framework_version): """ Args: @@ -171,6 +152,8 @@ def _registry_id(region, framework, py_version, account, framework_version): if _using_merged_images(region, framework, py_version, framework_version): if region in ASIMOV_OPT_IN_ACCOUNTS_BY_REGION: return ASIMOV_OPT_IN_ACCOUNTS_BY_REGION.get(region) + if region in ASIMOV_VALID_ACCOUNTS_BY_REGION: + return ASIMOV_VALID_ACCOUNTS_BY_REGION.get(region) return "763104351884" if region in OPT_IN_ACCOUNTS_BY_REGION: return OPT_IN_ACCOUNTS_BY_REGION.get(region) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 53fae0ae0b..494720779b 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -1,4 +1,4 @@ -# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of @@ -150,31 +150,18 @@ def test_tf_eia_images(): image_uri = fw_utils.create_image_uri( "us-west-2", "tensorflow-serving", - "ml.p3.2xlarge", + "ml.m4.xlarge", "1.14.0", "py3", accelerator_type="ml.eia1.medium", ) assert ( image_uri - == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference-eia:1.14.0-gpu" + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference-eia:1.14.0-cpu" ) def test_mxnet_eia_images(): - image_uri = fw_utils.create_image_uri( - "us-west-2", - "mxnet-serving", - "ml.p3.2xlarge", - "1.4.1", - "py2", - accelerator_type="ml.eia1.medium", - ) - assert ( - image_uri - == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference-eia:1.4.1-gpu-py2" - ) - image_uri = fw_utils.create_image_uri( "us-east-1", "mxnet-serving", @@ -218,10 +205,7 @@ def test_create_image_uri_merged(): image_uri = fw_utils.create_image_uri( "us-west-2", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py3" ) - assert ( - image_uri - == "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-serving:1.4.1-cpu-py3" - ) + assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.4.1-cpu-py3" image_uri = fw_utils.create_image_uri( "us-west-2", @@ -265,6 +249,49 @@ def test_create_image_uri_merged_py2(): ) +def test_create_image_uri_merged_gov_regions(): + image_uri = fw_utils.create_image_uri( + "us-iso-east-1", "tensorflow-scriptmode", "ml.m4.xlarge", "1.13.1", "py3" + ) + assert ( + image_uri + == "886529160074.dkr.ecr.us-iso-east-1.c2s.ic.gov/tensorflow-training:1.13.1-cpu-py3" + ) + + image_uri = fw_utils.create_image_uri( + "us-iso-east-1", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.14", "py2" + ) + assert ( + image_uri + == "886529160074.dkr.ecr.us-iso-east-1.c2s.ic.gov/tensorflow-training:1.14-gpu-py2" + ) + + image_uri = fw_utils.create_image_uri( + "us-iso-east-1", "tensorflow-serving", "ml.m4.xlarge", "1.13.0" + ) + assert ( + image_uri == "886529160074.dkr.ecr.us-iso-east-1.c2s.ic.gov/tensorflow-inference:1.13.0-cpu" + ) + + image_uri = fw_utils.create_image_uri("us-iso-east-1", "mxnet", "ml.p3.2xlarge", "1.4.1", "py3") + assert image_uri == "886529160074.dkr.ecr.us-iso-east-1.c2s.ic.gov/mxnet-training:1.4.1-gpu-py3" + + image_uri = fw_utils.create_image_uri( + "us-iso-east-1", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py3" + ) + assert ( + image_uri == "886529160074.dkr.ecr.us-iso-east-1.c2s.ic.gov/mxnet-inference:1.4.1-cpu-py3" + ) + + image_uri = fw_utils.create_image_uri( + "us-iso-east-1", "mxnet-serving", "ml.c4.2xlarge", "1.3.1", "py3" + ) + assert ( + image_uri + == "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mxnet-serving:1.3.1-cpu-py3" + ) + + def test_create_image_uri_accelerator_tf(): image_uri = fw_utils.create_image_uri( MOCK_REGION, "tensorflow", "ml.p3.2xlarge", "1.0", "py3", accelerator_type="ml.eia1.medium"