Skip to content

Commit 3a90f94

Browse files
authored
change: add XGBoost support to image_uris.retrieve() (#1714)
1 parent 211f4e5 commit 3a90f94

File tree

9 files changed

+357
-58
lines changed

9 files changed

+357
-58
lines changed
+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
{
2+
"scope": ["inference", "training"],
3+
"version_aliases": {
4+
"latest": "1"
5+
},
6+
"versions": {
7+
"1": {
8+
"registries": {
9+
"ap-east-1": "286214385809",
10+
"ap-northeast-1": "501404015308",
11+
"ap-northeast-2": "306986355934",
12+
"ap-south-1": "991648021394",
13+
"ap-southeast-1": "475088953585",
14+
"ap-southeast-2": "544295431143",
15+
"ca-central-1": "469771592824",
16+
"cn-north-1": "390948362332",
17+
"cn-northwest-1": "387376663083",
18+
"eu-central-1": "813361260812",
19+
"eu-north-1": "669576153137",
20+
"eu-west-1": "685385470294",
21+
"eu-west-2": "644912444149",
22+
"eu-west-3": "749696950732",
23+
"me-south-1": "249704162688",
24+
"sa-east-1": "855470959533",
25+
"us-east-1": "811284229777",
26+
"us-east-2": "825641698319",
27+
"us-gov-west-1": "226302683700",
28+
"us-iso-east-1": "490574956308",
29+
"us-west-1": "632365934929",
30+
"us-west-2": "433757028032"
31+
},
32+
"repository": "xgboost"
33+
},
34+
"0.90-1": {
35+
"processors": ["cpu"],
36+
"py_versions": ["py3"],
37+
"registries": {
38+
"ap-east-1": "651117190479",
39+
"ap-northeast-1": "354813040037",
40+
"ap-northeast-2": "366743142698",
41+
"ap-south-1": "720646828776",
42+
"ap-southeast-1": "121021644041",
43+
"ap-southeast-2": "783357654285",
44+
"ca-central-1": "341280168497",
45+
"cn-north-1": "450853457545",
46+
"cn-northwest-1": "451049120500",
47+
"eu-central-1": "492215442770",
48+
"eu-north-1": "662702820516",
49+
"eu-west-1": "141502667606",
50+
"eu-west-2": "764974769150",
51+
"eu-west-3": "659782779980",
52+
"me-south-1": "801668240914",
53+
"sa-east-1": "737474898029",
54+
"us-east-1": "683313688378",
55+
"us-east-2": "257758044811",
56+
"us-gov-west-1": "414596584902",
57+
"us-iso-east-1": "833128469047",
58+
"us-west-1": "746614075791",
59+
"us-west-2": "246618743249"
60+
},
61+
"repository": "sagemaker-xgboost"
62+
},
63+
"0.90-2": {
64+
"processors": ["cpu"],
65+
"py_versions": ["py3"],
66+
"registries": {
67+
"ap-east-1": "651117190479",
68+
"ap-northeast-1": "354813040037",
69+
"ap-northeast-2": "366743142698",
70+
"ap-south-1": "720646828776",
71+
"ap-southeast-1": "121021644041",
72+
"ap-southeast-2": "783357654285",
73+
"ca-central-1": "341280168497",
74+
"cn-north-1": "450853457545",
75+
"cn-northwest-1": "451049120500",
76+
"eu-central-1": "492215442770",
77+
"eu-north-1": "662702820516",
78+
"eu-west-1": "141502667606",
79+
"eu-west-2": "764974769150",
80+
"eu-west-3": "659782779980",
81+
"me-south-1": "801668240914",
82+
"sa-east-1": "737474898029",
83+
"us-east-1": "683313688378",
84+
"us-east-2": "257758044811",
85+
"us-gov-west-1": "414596584902",
86+
"us-iso-east-1": "833128469047",
87+
"us-west-1": "746614075791",
88+
"us-west-2": "246618743249"
89+
},
90+
"repository": "sagemaker-xgboost"
91+
},
92+
"1.0-1": {
93+
"processors": ["cpu"],
94+
"py_versions": ["py3"],
95+
"registries": {
96+
"ap-east-1": "651117190479",
97+
"ap-northeast-1": "354813040037",
98+
"ap-northeast-2": "366743142698",
99+
"ap-south-1": "720646828776",
100+
"ap-southeast-1": "121021644041",
101+
"ap-southeast-2": "783357654285",
102+
"ca-central-1": "341280168497",
103+
"cn-north-1": "450853457545",
104+
"cn-northwest-1": "451049120500",
105+
"eu-central-1": "492215442770",
106+
"eu-north-1": "662702820516",
107+
"eu-west-1": "141502667606",
108+
"eu-west-2": "764974769150",
109+
"eu-west-3": "659782779980",
110+
"me-south-1": "801668240914",
111+
"sa-east-1": "737474898029",
112+
"us-east-1": "683313688378",
113+
"us-east-2": "257758044811",
114+
"us-gov-west-1": "414596584902",
115+
"us-iso-east-1": "833128469047",
116+
"us-west-1": "746614075791",
117+
"us-west-2": "246618743249"
118+
},
119+
"repository": "sagemaker-xgboost"
120+
}
121+
}
122+
}

src/sagemaker/image_uris.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,12 @@ def retrieve(
6868
registry = _registry_from_region(region, version_config["registries"])
6969
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]
7070

71+
processor = _processor(
72+
instance_type, config.get("processors") or version_config.get("processors")
73+
)
74+
tag = _format_tag(version, processor, py_version)
75+
7176
repo = version_config["repository"]
72-
tag = _format_tag(version, _processor(instance_type, config.get("processors")), py_version)
7377

7478
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag)
7579

@@ -138,11 +142,17 @@ def _processor(instance_type, available_processors):
138142
logger.info("Ignoring unnecessary instance type: %s.", instance_type)
139143
return None
140144

145+
if not instance_type:
146+
raise ValueError(
147+
"Empty SageMaker instance type. For options, see: "
148+
"https://aws.amazon.com/sagemaker/pricing/instance-types"
149+
)
150+
141151
if instance_type.startswith("local"):
142152
processor = "cpu" if instance_type == "local" else "gpu"
143153
elif not instance_type.startswith("ml."):
144154
raise ValueError(
145-
"Invalid SageMaker instance type: {}. See: "
155+
"Invalid SageMaker instance type: {}. For options, see: "
146156
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
147157
)
148158
else:

tests/conftest.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,11 @@ def sklearn_version(request):
159159
return request.param
160160

161161

162-
@pytest.fixture(scope="module", params=["0.90-1"])
163-
def xgboost_version(request):
164-
return request.param
162+
@pytest.fixture(scope="module")
163+
def xgboost_framework_version(xgboost_version):
164+
if xgboost_version in ("1", "latest"):
165+
pytest.skip("Skipping XGBoost algorithm version.")
166+
return xgboost_version
165167

166168

167169
@pytest.fixture(scope="module", params=["py2", "py3"])
@@ -351,7 +353,7 @@ def pytest_generate_tests(metafunc):
351353

352354

353355
def _generate_all_framework_version_fixtures(metafunc):
354-
for fw in ("chainer", "tensorflow"):
356+
for fw in ("chainer", "tensorflow", "xgboost"):
355357
config = image_uris.config_for_framework(fw)
356358
if "scope" in config:
357359
_parametrize_framework_version_fixtures(metafunc, fw, config)

tests/unit/sagemaker/image_uris/expected_uris.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,6 @@ def framework_uri(repo, fw_version, account, py_version=None, processor="cpu", r
3131
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
3232

3333

34-
def algo_uri(algo, account, region):
34+
def algo_uri(algo, account, region, version=1):
3535
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
36-
return IMAGE_URI_FORMAT.format(account, region, domain, algo, 1)
36+
return IMAGE_URI_FORMAT.format(account, region, domain, algo, version)
+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import boto3
16+
17+
18+
def regions():
19+
boto_session = boto3.Session()
20+
for partition in boto_session.get_available_partitions():
21+
for region in boto_session.get_available_regions("sagemaker", partition_name=partition):
22+
yield region

tests/unit/sagemaker/image_uris/test_algos.py

+2-11
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import boto3
16-
1715
from sagemaker import image_uris
18-
from tests.unit.sagemaker.image_uris import expected_uris
16+
from tests.unit.sagemaker.image_uris import expected_uris, regions
1917

2018
ALGO_REGIONS_AND_ACCOUNTS = (
2119
{
@@ -60,13 +58,6 @@
6058
IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:1"
6159

6260

63-
def _regions():
64-
boto_session = boto3.Session()
65-
for partition in boto_session.get_available_partitions():
66-
for region in boto_session.get_available_regions("sagemaker", partition_name=partition):
67-
yield region
68-
69-
7061
def _accounts_for_algo(algo):
7162
for algo_account_dict in ALGO_REGIONS_AND_ACCOUNTS:
7263
if algo in algo_account_dict["algorithms"]:
@@ -79,7 +70,7 @@ def test_factorization_machines():
7970
algo = "factorization-machines"
8071
accounts = _accounts_for_algo(algo)
8172

82-
for region in _regions():
73+
for region in regions.regions():
8374
for scope in ("training", "inference"):
8475
uri = image_uris.retrieve(algo, region, image_scope=scope)
8576
assert expected_uris.algo_uri(algo, accounts[region], region) == uri

tests/unit/sagemaker/image_uris/test_retrieve.py

+39
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,34 @@ def test_retrieve_processor_type(config_for_framework):
374374
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-gpu-py3" == uri
375375

376376

377+
@patch("sagemaker.image_uris.config_for_framework")
378+
def test_retrieve_processor_type_from_version_specific_processor_config(config_for_framework):
379+
config = copy.deepcopy(BASE_CONFIG)
380+
del config["processors"]
381+
config["versions"]["1.0.0"]["processors"] = ["cpu"]
382+
config_for_framework.return_value = config
383+
384+
uri = image_uris.retrieve(
385+
framework="useless-string",
386+
version="1.0.0",
387+
py_version="py3",
388+
instance_type="ml.c4.xlarge",
389+
region="us-west-2",
390+
image_scope="training",
391+
)
392+
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri
393+
394+
uri = image_uris.retrieve(
395+
framework="useless-string",
396+
version="1.1.0",
397+
py_version="py3",
398+
instance_type="ml.c4.xlarge",
399+
region="us-west-2",
400+
image_scope="training",
401+
)
402+
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.1.0-py3" == uri
403+
404+
377405
@patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG)
378406
def test_retrieve_unsupported_processor_type(config_for_framework):
379407
with pytest.raises(ValueError) as e:
@@ -388,6 +416,17 @@ def test_retrieve_unsupported_processor_type(config_for_framework):
388416

389417
assert "Invalid SageMaker instance type: not-an-instance-type." in str(e.value)
390418

419+
with pytest.raises(ValueError) as e:
420+
image_uris.retrieve(
421+
framework="useless-string",
422+
version="1.0.0",
423+
py_version="py3",
424+
region="us-west-2",
425+
image_scope="training",
426+
)
427+
428+
assert "Empty SageMaker instance type." in str(e.value)
429+
391430
config = copy.deepcopy(BASE_CONFIG)
392431
config["processors"] = ["cpu"]
393432
config_for_framework.return_value = config

0 commit comments

Comments
 (0)