Skip to content

Commit 910eebd

Browse files
authored
feature: add support for Amazon algorithms in image_uris.retrieve() (#1709)
This also adds configuration for Factorization Machines.
1 parent cb85792 commit 910eebd

File tree

8 files changed

+258
-59
lines changed

8 files changed

+258
-59
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
{
2+
"scope": ["inference", "training"],
3+
"versions": {
4+
"1": {
5+
"registries": {
6+
"ap-east-1": "286214385809",
7+
"ap-northeast-1": "351501993468",
8+
"ap-northeast-2": "835164637446",
9+
"ap-south-1": "991648021394",
10+
"ap-southeast-1": "475088953585",
11+
"ap-southeast-2": "712309505854",
12+
"ca-central-1": "469771592824",
13+
"cn-north-1": "390948362332",
14+
"cn-northwest-1": "387376663083",
15+
"eu-central-1": "664544806723",
16+
"eu-north-1": "669576153137",
17+
"eu-west-1": "438346466558",
18+
"eu-west-2": "644912444149",
19+
"eu-west-3": "749696950732",
20+
"me-south-1": "249704162688",
21+
"sa-east-1": "855470959533",
22+
"us-east-1": "382416733822",
23+
"us-east-2": "404615174143",
24+
"us-gov-west-1": "226302683700",
25+
"us-iso-east-1": "490574956308",
26+
"us-west-1": "632365934929",
27+
"us-west-2": "174872318107"
28+
},
29+
"repository": "factorization-machines"
30+
}
31+
}
32+
}

src/sagemaker/image_uri_config/tensorflow.json

+15-30
Original file line numberDiff line numberDiff line change
@@ -765,8 +765,7 @@
765765
"us-west-1": "520713654638",
766766
"us-west-2": "520713654638"
767767
},
768-
"repository": "sagemaker-tensorflow-serving",
769-
"py_versions": []
768+
"repository": "sagemaker-tensorflow-serving"
770769
},
771770
"1.12.0": {
772771
"registries": {
@@ -793,8 +792,7 @@
793792
"us-west-1": "520713654638",
794793
"us-west-2": "520713654638"
795794
},
796-
"repository": "sagemaker-tensorflow-serving",
797-
"py_versions": []
795+
"repository": "sagemaker-tensorflow-serving"
798796
},
799797
"1.13.0": {
800798
"registries": {
@@ -821,8 +819,7 @@
821819
"us-west-1": "763104351884",
822820
"us-west-2": "763104351884"
823821
},
824-
"repository": "tensorflow-inference",
825-
"py_versions": []
822+
"repository": "tensorflow-inference"
826823
},
827824
"1.14.0": {
828825
"registries": {
@@ -849,8 +846,7 @@
849846
"us-west-1": "763104351884",
850847
"us-west-2": "763104351884"
851848
},
852-
"repository": "tensorflow-inference",
853-
"py_versions": []
849+
"repository": "tensorflow-inference"
854850
},
855851
"1.15.0": {
856852
"registries": {
@@ -877,8 +873,7 @@
877873
"us-west-1": "763104351884",
878874
"us-west-2": "763104351884"
879875
},
880-
"repository": "tensorflow-inference",
881-
"py_versions": []
876+
"repository": "tensorflow-inference"
882877
},
883878
"1.15.2": {
884879
"registries": {
@@ -905,8 +900,7 @@
905900
"us-west-1": "763104351884",
906901
"us-west-2": "763104351884"
907902
},
908-
"repository": "tensorflow-inference",
909-
"py_versions": []
903+
"repository": "tensorflow-inference"
910904
},
911905
"2.0.0": {
912906
"registries": {
@@ -933,8 +927,7 @@
933927
"us-west-1": "763104351884",
934928
"us-west-2": "763104351884"
935929
},
936-
"repository": "tensorflow-inference",
937-
"py_versions": []
930+
"repository": "tensorflow-inference"
938931
},
939932
"2.0.1": {
940933
"registries": {
@@ -961,8 +954,7 @@
961954
"us-west-1": "763104351884",
962955
"us-west-2": "763104351884"
963956
},
964-
"repository": "tensorflow-inference",
965-
"py_versions": []
957+
"repository": "tensorflow-inference"
966958
},
967959
"2.1.0": {
968960
"registries": {
@@ -989,8 +981,7 @@
989981
"us-west-1": "763104351884",
990982
"us-west-2": "763104351884"
991983
},
992-
"repository": "tensorflow-inference",
993-
"py_versions": []
984+
"repository": "tensorflow-inference"
994985
}
995986
}
996987
},
@@ -1059,8 +1050,7 @@
10591050
"us-west-1": "520713654638",
10601051
"us-west-2": "520713654638"
10611052
},
1062-
"repository": "sagemaker-tensorflow-serving-eia",
1063-
"py_versions": []
1053+
"repository": "sagemaker-tensorflow-serving-eia"
10641054
},
10651055
"1.12.0": {
10661056
"registries": {
@@ -1087,8 +1077,7 @@
10871077
"us-west-1": "520713654638",
10881078
"us-west-2": "520713654638"
10891079
},
1090-
"repository": "sagemaker-tensorflow-serving-eia",
1091-
"py_versions": []
1080+
"repository": "sagemaker-tensorflow-serving-eia"
10921081
},
10931082
"1.13.0": {
10941083
"registries": {
@@ -1115,8 +1104,7 @@
11151104
"us-west-1": "520713654638",
11161105
"us-west-2": "520713654638"
11171106
},
1118-
"repository": "sagemaker-tensorflow-serving-eia",
1119-
"py_versions": []
1107+
"repository": "sagemaker-tensorflow-serving-eia"
11201108
},
11211109
"1.14.0": {
11221110
"registries": {
@@ -1143,8 +1131,7 @@
11431131
"us-west-1": "763104351884",
11441132
"us-west-2": "763104351884"
11451133
},
1146-
"repository": "tensorflow-inference-eia",
1147-
"py_versions": []
1134+
"repository": "tensorflow-inference-eia"
11481135
},
11491136
"1.15.0": {
11501137
"registries": {
@@ -1171,8 +1158,7 @@
11711158
"us-west-1": "763104351884",
11721159
"us-west-2": "763104351884"
11731160
},
1174-
"repository": "tensorflow-inference-eia",
1175-
"py_versions": []
1161+
"repository": "tensorflow-inference-eia"
11761162
},
11771163
"2.0.0": {
11781164
"registries": {
@@ -1199,8 +1185,7 @@
11991185
"us-west-1": "763104351884",
12001186
"us-west-2": "763104351884"
12011187
},
1202-
"repository": "tensorflow-inference-eia",
1203-
"py_versions": []
1188+
"repository": "tensorflow-inference-eia"
12041189
}
12051190
}
12061191
}

src/sagemaker/image_uris.py

+38-9
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ def retrieve(
3636
"""Retrieves the ECR URI for the Docker image matching the given arguments.
3737
3838
Args:
39-
framework (str): The name of the framework.
39+
framework (str): The name of the framework or algorithm.
4040
region (str): The AWS region.
41-
version (str): The framework version. This is required if there is
42-
more than one supported version for the given framework.
41+
version (str): The framework or algorithm version. This is required if there is
42+
more than one supported version for the given framework or algorithm.
4343
py_version (str): The Python version. This is required if there is
4444
more than one supported Python version for the given framework version.
4545
instance_type (str): The SageMaker instance type. For supported types, see
@@ -58,7 +58,9 @@ def retrieve(
5858
ValueError: If the combination of arguments specified is not supported.
5959
"""
6060
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
61-
version_config = config["versions"][_version_for_config(version, config, framework)]
61+
62+
version = _validate_version_and_set_if_needed(version, config, framework)
63+
version_config = config["versions"][_version_for_config(version, config)]
6264

6365
py_version = _validate_py_version_and_set_if_needed(py_version, version_config)
6466
version_config = version_config.get(py_version) or version_config
@@ -67,7 +69,7 @@ def retrieve(
6769
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]
6870

6971
repo = version_config["repository"]
70-
tag = _format_tag(version, _processor(instance_type, config["processors"]), py_version)
72+
tag = _format_tag(version, _processor(instance_type, config.get("processors")), py_version)
7173

7274
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag)
7375

@@ -94,13 +96,33 @@ def config_for_framework(framework):
9496
return json.load(f)
9597

9698

97-
def _version_for_config(version, config, framework):
99+
def _validate_version_and_set_if_needed(version, config, framework):
100+
"""Checks if the framework/algorithm version is one of the supported versions."""
101+
available_versions = list(config["versions"].keys())
102+
103+
if len(available_versions) == 1:
104+
log_message = "Defaulting to the only supported framework/algorithm version: {}.".format(
105+
available_versions[0]
106+
)
107+
if version and version != available_versions[0]:
108+
logger.warning("%s Ignoring framework/algorithm version: %s.", log_message, version)
109+
elif not version:
110+
logger.info(log_message)
111+
112+
return available_versions[0]
113+
114+
available_versions += list(config.get("version_aliases", {}).keys())
115+
_validate_arg("{} version".format(framework), version, available_versions)
116+
117+
return version
118+
119+
120+
def _version_for_config(version, config):
98121
"""Returns the version string for retrieving a framework version's specific config."""
99122
if "version_aliases" in config:
100123
if version in config["version_aliases"].keys():
101124
return config["version_aliases"][version]
102125

103-
_validate_arg("{} version".format(framework), version, config["versions"].keys())
104126
return version
105127

106128

@@ -112,6 +134,10 @@ def _registry_from_region(region, registry_dict):
112134

113135
def _processor(instance_type, available_processors):
114136
"""Returns the processor type for the given instance type."""
137+
if not available_processors:
138+
logger.info("Ignoring unnecessary instance type: %s.", instance_type)
139+
return None
140+
115141
if instance_type.startswith("local"):
116142
processor = "cpu" if instance_type == "local" else "gpu"
117143
elif not instance_type.startswith("ml."):
@@ -129,9 +155,12 @@ def _processor(instance_type, available_processors):
129155

130156
def _validate_py_version_and_set_if_needed(py_version, version_config):
131157
"""Checks if the Python version is one of the supported versions."""
132-
available_versions = version_config.get("py_versions", version_config.keys())
158+
if "repository" in version_config:
159+
available_versions = version_config.get("py_versions")
160+
else:
161+
available_versions = list(version_config.keys())
133162

134-
if len(available_versions) == 0:
163+
if not available_versions:
135164
if py_version:
136165
logger.info("Ignoring unnecessary Python version: %s.", py_version)
137166
return None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
ALTERNATE_DOMAINS = {
16+
"cn-north-1": "amazonaws.com.cn",
17+
"cn-northwest-1": "amazonaws.com.cn",
18+
"us-iso-east-1": "c2s.ic.gov",
19+
}
20+
DOMAIN = "amazonaws.com"
21+
IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:{}"
22+
REGION = "us-west-2"
23+
24+
25+
def framework_uri(repo, fw_version, account, py_version=None, processor="cpu", region=REGION):
26+
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
27+
tag = "{}-{}".format(fw_version, processor)
28+
if py_version:
29+
tag = "-".join((tag, py_version))
30+
31+
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
32+
33+
34+
def algo_uri(algo, account, region):
35+
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
36+
return IMAGE_URI_FORMAT.format(account, region, domain, algo, 1)

0 commit comments

Comments
 (0)