From 0ab23ca8ddceab490cd6a0e0a68db682a2c80489 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Tue, 14 Jul 2020 10:04:26 -0700 Subject: [PATCH 1/2] feature: add support for Amazon algorithms in image_uris.retrieve() This also adds configuration for Factorization Machines. --- .../factorization-machines.json | 32 +++++++ .../image_uri_config/tensorflow.json | 45 ++++------ src/sagemaker/image_uris.py | 42 +++++++-- tests/unit/sagemaker/image_uris/__init__.py | 13 +++ .../sagemaker/image_uris/expected_uris.py | 36 ++++++++ tests/unit/sagemaker/image_uris/test_algos.py | 85 +++++++++++++++++++ .../image_uris/test_dlc_frameworks.py | 23 ++--- .../sagemaker/image_uris/test_retrieve.py | 25 +++++- 8 files changed, 242 insertions(+), 59 deletions(-) create mode 100644 src/sagemaker/image_uri_config/factorization-machines.json create mode 100644 tests/unit/sagemaker/image_uris/__init__.py create mode 100644 tests/unit/sagemaker/image_uris/expected_uris.py create mode 100644 tests/unit/sagemaker/image_uris/test_algos.py diff --git a/src/sagemaker/image_uri_config/factorization-machines.json b/src/sagemaker/image_uri_config/factorization-machines.json new file mode 100644 index 0000000000..56f425c0f3 --- /dev/null +++ b/src/sagemaker/image_uri_config/factorization-machines.json @@ -0,0 +1,32 @@ +{ + "scope": ["inference", "training"], + "versions": { + "1": { + "registries": { + "ap-east-1": "286214385809", + "ap-northeast-1": "351501993468", + "ap-northeast-2": "835164637446", + "ap-south-1": "991648021394", + "ap-southeast-1": "475088953585", + "ap-southeast-2": "712309505854", + "ca-central-1": "469771592824", + "cn-north-1": "390948362332", + "cn-northwest-1": "387376663083", + "eu-central-1": "664544806723", + "eu-north-1": "669576153137", + "eu-west-1": "438346466558", + "eu-west-2": "644912444149", + "eu-west-3": "749696950732", + "me-south-1": "249704162688", + "sa-east-1": "855470959533", + "us-east-1": "382416733822", + "us-east-2": "404615174143", + "us-gov-west-1": "226302683700", + "us-iso-east-1": "490574956308", + "us-west-1": "632365934929", + "us-west-2": "174872318107" + }, + "repository": "factorization-machines" + } + } +} diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 196c4fbcc5..2e5b628d06 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -765,8 +765,7 @@ "us-west-1": "520713654638", "us-west-2": "520713654638" }, - "repository": "sagemaker-tensorflow-serving", - "py_versions": [] + "repository": "sagemaker-tensorflow-serving" }, "1.12.0": { "registries": { @@ -793,8 +792,7 @@ "us-west-1": "520713654638", "us-west-2": "520713654638" }, - "repository": "sagemaker-tensorflow-serving", - "py_versions": [] + "repository": "sagemaker-tensorflow-serving" }, "1.13.0": { "registries": { @@ -821,8 +819,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "tensorflow-inference", - "py_versions": [] + "repository": "tensorflow-inference" }, "1.14.0": { "registries": { @@ -849,8 +846,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "tensorflow-inference", - "py_versions": [] + "repository": "tensorflow-inference" }, "1.15.0": { "registries": { @@ -877,8 +873,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "tensorflow-inference", - "py_versions": [] + "repository": "tensorflow-inference" }, "1.15.2": { "registries": { @@ -905,8 +900,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "tensorflow-inference", - "py_versions": [] + "repository": "tensorflow-inference" }, "2.0.0": { "registries": { @@ -933,8 +927,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "tensorflow-inference", - "py_versions": [] + "repository": "tensorflow-inference" }, "2.0.1": { "registries": { @@ -961,8 +954,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "tensorflow-inference", - "py_versions": [] + "repository": "tensorflow-inference" }, "2.1.0": { "registries": { @@ -989,8 +981,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "tensorflow-inference", - "py_versions": [] + "repository": "tensorflow-inference" } } }, @@ -1059,8 +1050,7 @@ "us-west-1": "520713654638", "us-west-2": "520713654638" }, - "repository": "sagemaker-tensorflow-serving-eia", - "py_versions": [] + "repository": "sagemaker-tensorflow-serving-eia" }, "1.12.0": { "registries": { @@ -1087,8 +1077,7 @@ "us-west-1": "520713654638", "us-west-2": "520713654638" }, - "repository": "sagemaker-tensorflow-serving-eia", - "py_versions": [] + "repository": "sagemaker-tensorflow-serving-eia" }, "1.13.0": { "registries": { @@ -1115,8 +1104,7 @@ "us-west-1": "520713654638", "us-west-2": "520713654638" }, - "repository": "sagemaker-tensorflow-serving-eia", - "py_versions": [] + "repository": "sagemaker-tensorflow-serving-eia" }, "1.14.0": { "registries": { @@ -1143,8 +1131,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "tensorflow-inference-eia", - "py_versions": [] + "repository": "tensorflow-inference-eia" }, "1.15.0": { "registries": { @@ -1171,8 +1158,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "tensorflow-inference-eia", - "py_versions": [] + "repository": "tensorflow-inference-eia" }, "2.0.0": { "registries": { @@ -1199,8 +1185,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "tensorflow-inference-eia", - "py_versions": [] + "repository": "tensorflow-inference-eia" } } } diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 66de9f066e..62c666f5ef 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -36,10 +36,10 @@ def retrieve( """Retrieves the ECR URI for the Docker image matching the given arguments. Args: - framework (str): The name of the framework. + framework (str): The name of the framework or algorithm. region (str): The AWS region. - version (str): The framework version. This is required if there is - more than one supported version for the given framework. + version (str): The framework or algorithm version. This is required if there is + more than one supported version for the given framework or algorithm. py_version (str): The Python version. This is required if there is more than one supported Python version for the given framework version. instance_type (str): The SageMaker instance type. For supported types, see @@ -58,7 +58,9 @@ def retrieve( ValueError: If the combination of arguments specified is not supported. """ config = _config_for_framework_and_scope(framework, image_scope, accelerator_type) - version_config = config["versions"][_version_for_config(version, config, framework)] + + version = _validate_version_and_set_if_needed(version, config, framework) + version_config = config["versions"][_version_for_config(version, config)] py_version = _validate_py_version_and_set_if_needed(py_version, version_config) version_config = version_config.get(py_version) or version_config @@ -67,7 +69,7 @@ def retrieve( hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"] repo = version_config["repository"] - tag = _format_tag(version, _processor(instance_type, config["processors"]), py_version) + tag = _format_tag(version, _processor(instance_type, config.get("processors")), py_version) return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag) @@ -94,13 +96,28 @@ def config_for_framework(framework): return json.load(f) -def _version_for_config(version, config, framework): +def _validate_version_and_set_if_needed(version, config, framework): + """Checks if the framework/algorithm version is one of the supported versions.""" + available_versions = list(config["versions"].keys()) + + if len(available_versions) == 1: + logger.info( + "Defaulting to only available framework/algorithm version: %s", available_versions[0] + ) + return available_versions[0] + + available_versions += list(config.get("version_aliases", {}).keys()) + _validate_arg("{} version".format(framework), version, available_versions) + + return version + + +def _version_for_config(version, config): """Returns the version string for retrieving a framework version's specific config.""" if "version_aliases" in config: if version in config["version_aliases"].keys(): return config["version_aliases"][version] - _validate_arg("{} version".format(framework), version, config["versions"].keys()) return version @@ -112,6 +129,10 @@ def _registry_from_region(region, registry_dict): def _processor(instance_type, available_processors): """Returns the processor type for the given instance type.""" + if not available_processors: + logger.info("Ignoring unnecessary instance type: %s.", instance_type) + return None + if instance_type.startswith("local"): processor = "cpu" if instance_type == "local" else "gpu" elif not instance_type.startswith("ml."): @@ -129,9 +150,12 @@ def _processor(instance_type, available_processors): def _validate_py_version_and_set_if_needed(py_version, version_config): """Checks if the Python version is one of the supported versions.""" - available_versions = version_config.get("py_versions", version_config.keys()) + if "repository" in version_config: + available_versions = version_config.get("py_versions") + else: + available_versions = list(version_config.keys()) - if len(available_versions) == 0: + if not available_versions: if py_version: logger.info("Ignoring unnecessary Python version: %s.", py_version) return None diff --git a/tests/unit/sagemaker/image_uris/__init__.py b/tests/unit/sagemaker/image_uris/__init__.py new file mode 100644 index 0000000000..ec1e80a0b4 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import diff --git a/tests/unit/sagemaker/image_uris/expected_uris.py b/tests/unit/sagemaker/image_uris/expected_uris.py new file mode 100644 index 0000000000..d0de1f7b45 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/expected_uris.py @@ -0,0 +1,36 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +ALTERNATE_DOMAINS = { + "cn-north-1": "amazonaws.com.cn", + "cn-northwest-1": "amazonaws.com.cn", + "us-iso-east-1": "c2s.ic.gov", +} +DOMAIN = "amazonaws.com" +IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:{}" +REGION = "us-west-2" + + +def framework_uri(repo, fw_version, account, py_version=None, processor="cpu", region=REGION): + domain = ALTERNATE_DOMAINS.get(region, DOMAIN) + tag = "{}-{}".format(fw_version, processor) + if py_version: + tag = "-".join((tag, py_version)) + + return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) + + +def algo_uri(algo, account, region): + domain = ALTERNATE_DOMAINS.get(region, DOMAIN) + return IMAGE_URI_FORMAT.format(account, region, domain, algo, 1) diff --git a/tests/unit/sagemaker/image_uris/test_algos.py b/tests/unit/sagemaker/image_uris/test_algos.py new file mode 100644 index 0000000000..0d59dc0490 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/test_algos.py @@ -0,0 +1,85 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import boto3 + +from sagemaker import image_uris +from tests.unit.sagemaker.image_uris import expected_uris + +ALGO_REGIONS_AND_ACCOUNTS = ( + { + "algorithms": ( + "pca", + "kmeans", + "linear-learner", + "factorization-machines", + "ntm", + "randomcutforest", + "knn", + "object2vec", + "ipinsights", + ), + "accounts": { + "ap-east-1": "286214385809", + "ap-northeast-1": "351501993468", + "ap-northeast-2": "835164637446", + "ap-south-1": "991648021394", + "ap-southeast-1": "475088953585", + "ap-southeast-2": "712309505854", + "ca-central-1": "469771592824", + "cn-north-1": "390948362332", + "cn-northwest-1": "387376663083", + "eu-central-1": "664544806723", + "eu-north-1": "669576153137", + "eu-west-1": "438346466558", + "eu-west-2": "644912444149", + "eu-west-3": "749696950732", + "me-south-1": "249704162688", + "sa-east-1": "855470959533", + "us-east-1": "382416733822", + "us-east-2": "404615174143", + "us-gov-west-1": "226302683700", + "us-iso-east-1": "490574956308", + "us-west-1": "632365934929", + "us-west-2": "174872318107", + }, + }, +) + +IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:1" + + +def _regions(): + boto_session = boto3.Session() + for partition in boto_session.get_available_partitions(): + for region in boto_session.get_available_regions("sagemaker", partition_name=partition): + yield region + + +def _accounts_for_algo(algo): + for algo_account_dict in ALGO_REGIONS_AND_ACCOUNTS: + if algo in algo_account_dict["algorithms"]: + return algo_account_dict["accounts"] + + return {} + + +def test_factorization_machines(): + algo = "factorization-machines" + accounts = _accounts_for_algo(algo) + + for region in _regions(): + for scope in ("training", "inference"): + uri = image_uris.retrieve(algo, region, image_scope=scope) + assert expected_uris.algo_uri(algo, accounts[region], region) == uri diff --git a/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py b/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py index 4cbc4e6533..1e3beeb288 100644 --- a/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py +++ b/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py @@ -15,14 +15,8 @@ from packaging.version import Version from sagemaker import image_uris +from tests.unit.sagemaker.image_uris import expected_uris -ALTERNATE_DOMAINS = { - "cn-north-1": "amazonaws.com.cn", - "cn-northwest-1": "amazonaws.com.cn", - "us-iso-east-1": "c2s.ic.gov", -} -DOMAIN = "amazonaws.com" -IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:{}" INSTANCE_TYPES_AND_PROCESSORS = (("ml.c4.xlarge", "cpu"), ("ml.p2.xlarge", "gpu")) REGION = "us-west-2" @@ -46,13 +40,6 @@ } -def _expected_uri(repo, fw_version, account, py_version=None, processor="cpu", region=REGION): - domain = ALTERNATE_DOMAINS.get(region, DOMAIN) - tag = "-".join([x for x in (fw_version, processor, py_version) if x]) - - return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) - - def test_chainer(chainer_version, chainer_py_version): for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS: for scope in ("training", "inference"): @@ -64,7 +51,7 @@ def test_chainer(chainer_version, chainer_py_version): instance_type=instance_type, image_scope=scope, ) - expected = _expected_uri( + expected = expected_uris.framework_uri( repo="sagemaker-chainer", fw_version=chainer_version, py_version=chainer_py_version, @@ -82,7 +69,7 @@ def test_chainer(chainer_version, chainer_py_version): instance_type="ml.c4.xlarge", image_scope="training", ) - expected = _expected_uri( + expected = expected_uris.framework_uri( repo="sagemaker-chainer", fw_version=chainer_version, py_version=chainer_py_version, @@ -142,7 +129,7 @@ def _expected_tf_training_uri(tf_training_version, py_version, processor="cpu", else: account = DLC_ACCOUNT if region == REGION else DLC_ALTERNATE_REGION_ACCOUNTS[region] - return _expected_uri( + return expected_uris.framework_uri( repo, tf_training_version, account, @@ -221,7 +208,7 @@ def _expected_tf_inference_uri(tf_inference_version, processor="cpu", region=REG else: account = DLC_ACCOUNT if region == REGION else DLC_ALTERNATE_REGION_ACCOUNTS[region] - return _expected_uri( + return expected_uris.framework_uri( repo, tf_inference_version, account, py_version, processor=processor, region=region, ) diff --git a/tests/unit/sagemaker/image_uris/test_retrieve.py b/tests/unit/sagemaker/image_uris/test_retrieve.py index 43f397d5d3..0dd2e017dc 100644 --- a/tests/unit/sagemaker/image_uris/test_retrieve.py +++ b/tests/unit/sagemaker/image_uris/test_retrieve.py @@ -29,6 +29,11 @@ "repository": "dummy", "py_versions": ["py3", "py37"], }, + "1.1.0": { + "registries": {"us-west-2": "123412341234"}, + "repository": "dummy", + "py_versions": ["py3", "py37"], + }, }, } @@ -123,6 +128,22 @@ def test_retrieve_aliased_version(config_for_framework): assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:{}-cpu-py3".format(version) == uri +@patch("sagemaker.image_uris.config_for_framework") +def test_retrieve_default_version_if_possible(config_for_framework): + config = copy.deepcopy(BASE_CONFIG) + del config["versions"]["1.1.0"] + config_for_framework.return_value = config + + uri = image_uris.retrieve( + framework="useless-string", + py_version="py3", + instance_type="ml.c4.xlarge", + region="us-west-2", + image_scope="training", + ) + assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri + + @patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG) def test_retrieve_unsupported_version(config_for_framework): with pytest.raises(ValueError) as e: @@ -136,7 +157,7 @@ def test_retrieve_unsupported_version(config_for_framework): ) assert "Unsupported some-framework version: 1." in str(e.value) - assert "Supported some-framework version(s): 1.0.0." in str(e.value) + assert "Supported some-framework version(s): 1.0.0, 1.1.0." in str(e.value) with pytest.raises(ValueError) as e: image_uris.retrieve( @@ -148,7 +169,7 @@ def test_retrieve_unsupported_version(config_for_framework): ) assert "Unsupported some-framework version: None." in str(e.value) - assert "Supported some-framework version(s): 1.0.0." in str(e.value) + assert "Supported some-framework version(s): 1.0.0, 1.1.0." in str(e.value) @patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG) From ad61d985b3d137dc2fb11ca8e6ae387d2eaae87f Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Tue, 14 Jul 2020 16:30:35 -0700 Subject: [PATCH 2/2] address PR comment --- src/sagemaker/image_uris.py | 9 +++++++-- tests/unit/sagemaker/image_uris/test_retrieve.py | 13 ++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 62c666f5ef..773dda7039 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -101,9 +101,14 @@ def _validate_version_and_set_if_needed(version, config, framework): available_versions = list(config["versions"].keys()) if len(available_versions) == 1: - logger.info( - "Defaulting to only available framework/algorithm version: %s", available_versions[0] + log_message = "Defaulting to the only supported framework/algorithm version: {}.".format( + available_versions[0] ) + if version and version != available_versions[0]: + logger.warning("%s Ignoring framework/algorithm version: %s.", log_message, version) + elif not version: + logger.info(log_message) + return available_versions[0] available_versions += list(config.get("version_aliases", {}).keys()) diff --git a/tests/unit/sagemaker/image_uris/test_retrieve.py b/tests/unit/sagemaker/image_uris/test_retrieve.py index 0dd2e017dc..3b405ae7b8 100644 --- a/tests/unit/sagemaker/image_uris/test_retrieve.py +++ b/tests/unit/sagemaker/image_uris/test_retrieve.py @@ -129,7 +129,7 @@ def test_retrieve_aliased_version(config_for_framework): @patch("sagemaker.image_uris.config_for_framework") -def test_retrieve_default_version_if_possible(config_for_framework): +def test_retrieve_default_version_if_possible(config_for_framework, caplog): config = copy.deepcopy(BASE_CONFIG) del config["versions"]["1.1.0"] config_for_framework.return_value = config @@ -143,6 +143,17 @@ def test_retrieve_default_version_if_possible(config_for_framework): ) assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri + uri = image_uris.retrieve( + framework="useless-string", + version="invalid-version", + py_version="py3", + instance_type="ml.c4.xlarge", + region="us-west-2", + image_scope="training", + ) + assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri + assert "Ignoring framework/algorithm version: invalid-version." in caplog.text + @patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG) def test_retrieve_unsupported_version(config_for_framework):