Skip to content

feature: start new module for retrieving prebuilt SageMaker image URIs #1701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
recursive-include src/sagemaker *
recursive-include src/sagemaker *.py

include src/sagemaker/image_uri_config/*.json

include VERSION
include LICENSE.txt
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def read_version():
packages=find_packages("src"),
package_dir={"": "src"},
py_modules=[os.path.splitext(os.path.basename(path))[0] for path in glob("src/*.py")],
include_package_data=True,
long_description=read("README.rst"),
author="Amazon Web Services",
url="https://github.com/aws/sagemaker-python-sdk/",
Expand Down
94 changes: 94 additions & 0 deletions src/sagemaker/image_uri_config/chainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
{
"processors": ["cpu", "gpu"],
"version_aliases": {
"4.0": "4.0.0",
"4.1": "4.1.0",
"5.0": "5.0.0"
},
"versions": {
"4.0.0": {
"registries": {
"ap-east-1": "057415533634",
"ap-northeast-1": "520713654638",
"ap-northeast-2": "520713654638",
"ap-south-1": "520713654638",
"ap-southeast-1": "520713654638",
"ap-southeast-2": "520713654638",
"ca-central-1": "520713654638",
"cn-north-1": "422961961927",
"cn-northwest-1": "423003514399",
"eu-central-1": "520713654638",
"eu-north-1": "520713654638",
"eu-west-1": "520713654638",
"eu-west-2": "520713654638",
"eu-west-3": "520713654638",
"me-south-1": "724002660598",
"sa-east-1": "520713654638",
"us-east-1": "520713654638",
"us-east-2": "520713654638",
"us-gov-west-1": "246785580436",
"us-iso-east-1": "744548109606",
"us-west-1": "520713654638",
"us-west-2": "520713654638"
},
"repository": "sagemaker-chainer",
"py_versions": ["py2", "py3"]
},
"4.1.0": {
"registries": {
"ap-east-1": "057415533634",
"ap-northeast-1": "520713654638",
"ap-northeast-2": "520713654638",
"ap-south-1": "520713654638",
"ap-southeast-1": "520713654638",
"ap-southeast-2": "520713654638",
"ca-central-1": "520713654638",
"cn-north-1": "422961961927",
"cn-northwest-1": "423003514399",
"eu-central-1": "520713654638",
"eu-north-1": "520713654638",
"eu-west-1": "520713654638",
"eu-west-2": "520713654638",
"eu-west-3": "520713654638",
"me-south-1": "724002660598",
"sa-east-1": "520713654638",
"us-east-1": "520713654638",
"us-east-2": "520713654638",
"us-gov-west-1": "246785580436",
"us-iso-east-1": "744548109606",
"us-west-1": "520713654638",
"us-west-2": "520713654638"
},
"repository": "sagemaker-chainer",
"py_versions": ["py2", "py3"]
},
"5.0.0": {
"registries": {
"ap-east-1": "057415533634",
"ap-northeast-1": "520713654638",
"ap-northeast-2": "520713654638",
"ap-south-1": "520713654638",
"ap-southeast-1": "520713654638",
"ap-southeast-2": "520713654638",
"ca-central-1": "520713654638",
"cn-north-1": "422961961927",
"cn-northwest-1": "423003514399",
"eu-central-1": "520713654638",
"eu-north-1": "520713654638",
"eu-west-1": "520713654638",
"eu-west-2": "520713654638",
"eu-west-3": "520713654638",
"me-south-1": "724002660598",
"sa-east-1": "520713654638",
"us-east-1": "520713654638",
"us-east-2": "520713654638",
"us-gov-west-1": "246785580436",
"us-iso-east-1": "744548109606",
"us-west-1": "520713654638",
"us-west-2": "520713654638"
},
"repository": "sagemaker-chainer",
"py_versions": ["py2", "py3"]
}
}
}
127 changes: 127 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Functions for generating ECR image URIs for pre-built SageMaker Docker images."""
from __future__ import absolute_import

import json
import os

from sagemaker import utils

ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}:{tag}"


def retrieve(framework, region, version=None, py_version=None, instance_type=None):
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Args:
framework (str): The name of the framework.
region (str): The AWS region.
version (str): The framework version. This is required if there is
more than one supported version for the given framework.
py_version (str): The Python version. This is required if there is
more than one supported Python version for the given framework version.
instance_type (str): The SageMaker instance type. For supported types, see
https://aws.amazon.com/sagemaker/pricing/instance-types. This is required if
there are different images for different processor types.

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
ValueError: If the framework version, Python version, processor type, or region is
not supported given the other arguments.
"""
config = config_for_framework(framework)
version_config = config["versions"][_version_for_config(version, config, framework)]

registry = _registry_from_region(region, version_config["registries"])
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]

repo = version_config["repository"]

_validate_py_version(py_version, version_config["py_versions"], framework, version)
tag = "{}-{}-{}".format(version, _processor(instance_type, config["processors"]), py_version)

return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag)


def config_for_framework(framework):
"""Loads the JSON config for the given framework."""
fname = os.path.join(os.path.dirname(__file__), "image_uri_config", "{}.json".format(framework))
with open(fname) as f:
return json.load(f)


def _version_for_config(version, config, framework):
"""Returns the version string for retrieving a framework version's specific config."""
if "version_aliases" in config:
if version in config["version_aliases"].keys():
return config["version_aliases"][version]

available_versions = config["versions"].keys()
if version in available_versions:
return version

raise ValueError(
"Unsupported {} version: {}. "
"You may need to upgrade your SDK version (pip install -U sagemaker) for newer versions. "
"Supported version(s): {}.".format(framework, version, ", ".join(available_versions))
)


def _registry_from_region(region, registry_dict):
"""Returns the ECR registry (AWS account number) for the given region."""
available_regions = registry_dict.keys()
if region not in available_regions:
raise ValueError(
"Unsupported region: {}. You may need to upgrade "
"your SDK version (pip install -U sagemaker) for newer regions. "
"Supported region(s): {}.".format(region, ", ".join(available_regions))
)

return registry_dict[region]


def _processor(instance_type, available_processors):
"""Returns the processor type for the given instance type."""
if instance_type.startswith("local"):
processor = "cpu" if instance_type == "local" else "gpu"
elif not instance_type.startswith("ml."):
raise ValueError(
"Invalid SageMaker instance type: {}. See: "
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
)
else:
family = instance_type.split(".")[1]
processor = "gpu" if family[0] in ("g", "p") else "cpu"

if processor in available_processors:
return processor

raise ValueError(
"Unsupported processor type: {} (for {}). "
"Supported type(s): {}.".format(processor, instance_type, ", ".join(available_processors))
)


def _validate_py_version(py_version, available_versions, framework, fw_version):
"""Checks if the Python version is one of the supported versions."""
if py_version not in available_versions:
raise ValueError(
"Unsupported Python version for {} {}: {}. You may need to upgrade "
"your SDK version (pip install -U sagemaker) for newer versions. "
"Supported Python version(s): {}.".format(
framework, fw_version, py_version, ", ".join(available_versions)
)
)
16 changes: 10 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from botocore.config import Config
from packaging.version import Version

from sagemaker import Session, utils
from sagemaker import Session, image_uris, utils
from sagemaker.local import LocalSession
from sagemaker.rl import RLEstimator

Expand Down Expand Up @@ -110,11 +110,6 @@ def custom_bucket_name(boto_session):
return "{}-{}-{}".format(CUSTOM_BUCKET_NAME_PREFIX, region, account)


@pytest.fixture(scope="module", params=["4.0", "4.0.0", "4.1", "4.1.0", "5.0", "5.0.0"])
def chainer_version(request):
return request.param


@pytest.fixture(scope="module", params=["py2", "py3"])
def chainer_py_version(request):
return request.param
Expand Down Expand Up @@ -405,3 +400,12 @@ def pytest_generate_tests(metafunc):
):
params.append("ml.p2.xlarge")
metafunc.parametrize("instance_type", params, scope="session")

for fw in ("chainer",):
fixture_name = "{}_version".format(fw)
if fixture_name in metafunc.fixturenames:
config = image_uris.config_for_framework(fw)
versions = list(config["versions"].keys()) + list(
config.get("version_aliases", {}).keys()
)
metafunc.parametrize(fixture_name, versions, scope="session")
69 changes: 69 additions & 0 deletions tests/unit/sagemaker/image_uris/test_frameworks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker import image_uris

ACCOUNT = "520713654638"
DOMAIN = "amazonaws.com"
IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:{}-{}-{}"
REGION = "us-west-2"

ALTERNATE_REGION_DOMAIN_AND_ACCOUNTS = (
("ap-east-1", DOMAIN, "057415533634"),
("cn-north-1", "amazonaws.com.cn", "422961961927"),
("cn-northwest-1", "amazonaws.com.cn", "423003514399"),
("me-south-1", DOMAIN, "724002660598"),
("us-gov-west-1", DOMAIN, "246785580436"),
("us-iso-east-1", "c2s.ic.gov", "744548109606"),
)


def _expected_uri(
repo, fw_version, processor, py_version, account=ACCOUNT, region=REGION, domain=DOMAIN
):
return IMAGE_URI_FORMAT.format(account, region, domain, repo, fw_version, processor, py_version)


def test_chainer(chainer_version, chainer_py_version):
for instance_type, processor in (("ml.c4.xlarge", "cpu"), ("ml.p2.xlarge", "gpu")):
uri = image_uris.retrieve(
framework="chainer",
region=REGION,
version=chainer_version,
py_version=chainer_py_version,
instance_type=instance_type,
)
expected = _expected_uri(
"sagemaker-chainer", chainer_version, processor, chainer_py_version
)
assert expected == uri

for region, domain, account in ALTERNATE_REGION_DOMAIN_AND_ACCOUNTS:
uri = image_uris.retrieve(
framework="chainer",
region=region,
version=chainer_version,
py_version=chainer_py_version,
instance_type=instance_type,
)
expected = _expected_uri(
"sagemaker-chainer",
chainer_version,
processor,
chainer_py_version,
account=account,
region=region,
domain=domain,
)
assert expected == uri
Loading