From 59a3f191e83bfc9ac268faa78e494a7554abfd8b Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Tue, 14 Jul 2020 15:01:31 -0700 Subject: [PATCH 1/2] change: add XGBoost support to image_uris.retrieve() --- src/sagemaker/image_uri_config/xgboost.json | 122 ++++++++++++++++++ src/sagemaker/image_uris.py | 14 +- tests/conftest.py | 10 +- .../sagemaker/image_uris/expected_uris.py | 4 +- tests/unit/sagemaker/image_uris/test_algos.py | 13 +- .../sagemaker/image_uris/test_retrieve.py | 39 ++++++ .../unit/sagemaker/image_uris/test_xgboost.py | 105 +++++++++++++++ tests/unit/test_xgboost.py | 86 ++++++------ 8 files changed, 335 insertions(+), 58 deletions(-) create mode 100644 src/sagemaker/image_uri_config/xgboost.json create mode 100644 tests/unit/sagemaker/image_uris/test_xgboost.py diff --git a/src/sagemaker/image_uri_config/xgboost.json b/src/sagemaker/image_uri_config/xgboost.json new file mode 100644 index 0000000000..246f80cbd0 --- /dev/null +++ b/src/sagemaker/image_uri_config/xgboost.json @@ -0,0 +1,122 @@ +{ + "scope": ["inference", "training"], + "version_aliases": { + "latest": "1" + }, + "versions": { + "1": { + "registries": { + "ap-east-1": "286214385809", + "ap-northeast-1": "501404015308", + "ap-northeast-2": "306986355934", + "ap-south-1": "991648021394", + "ap-southeast-1": "475088953585", + "ap-southeast-2": "544295431143", + "ca-central-1": "469771592824", + "cn-north-1": "390948362332", + "cn-northwest-1": "387376663083", + "eu-central-1": "813361260812", + "eu-north-1": "669576153137", + "eu-west-1": "685385470294", + "eu-west-2": "644912444149", + "eu-west-3": "749696950732", + "me-south-1": "249704162688", + "sa-east-1": "855470959533", + "us-east-1": "811284229777", + "us-east-2": "825641698319", + "us-gov-west-1": "226302683700", + "us-iso-east-1": "490574956308", + "us-west-1": "632365934929", + "us-west-2": "433757028032" + }, + "repository": "xgboost" + }, + "0.90-1": { + "processors": ["cpu"], + "py_versions": ["py3"], + "registries": { + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-south-1": "720646828776", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ca-central-1": "341280168497", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-north-1": "662702820516", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "repository": "sagemaker-xgboost" + }, + "0.90-2": { + "processors": ["cpu"], + "py_versions": ["py3"], + "registries": { + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-south-1": "720646828776", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ca-central-1": "341280168497", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-north-1": "662702820516", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "repository": "sagemaker-xgboost" + }, + "1.0-1": { + "processors": ["cpu"], + "py_versions": ["py3"], + "registries": { + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-south-1": "720646828776", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ca-central-1": "341280168497", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-north-1": "662702820516", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-west-1": "746614075791", + "us-west-2": "246618743249" + }, + "repository": "sagemaker-xgboost" + } + } +} diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 773dda7039..ebb234bcb0 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -68,8 +68,12 @@ def retrieve( registry = _registry_from_region(region, version_config["registries"]) hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"] + processor = _processor( + instance_type, config.get("processors") or version_config.get("processors") + ) + tag = _format_tag(version, processor, py_version) + repo = version_config["repository"] - tag = _format_tag(version, _processor(instance_type, config.get("processors")), py_version) return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag) @@ -138,11 +142,17 @@ def _processor(instance_type, available_processors): logger.info("Ignoring unnecessary instance type: %s.", instance_type) return None + if not instance_type: + raise ValueError( + "Empty SageMaker instance type. For options, see: " + "https://aws.amazon.com/sagemaker/pricing/instance-types" + ) + if instance_type.startswith("local"): processor = "cpu" if instance_type == "local" else "gpu" elif not instance_type.startswith("ml."): raise ValueError( - "Invalid SageMaker instance type: {}. See: " + "Invalid SageMaker instance type: {}. For options, see: " "https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type) ) else: diff --git a/tests/conftest.py b/tests/conftest.py index 260a07204f..5c2b83cea9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -159,9 +159,11 @@ def sklearn_version(request): return request.param -@pytest.fixture(scope="module", params=["0.90-1"]) -def xgboost_version(request): - return request.param +@pytest.fixture(scope="module") +def xgboost_framework_version(xgboost_version): + if xgboost_version in ("1", "latest"): + pytest.skip("Skipping XGBoost algorithm version.") + return xgboost_version @pytest.fixture(scope="module") @@ -365,7 +367,7 @@ def pytest_generate_tests(metafunc): def _generate_all_framework_version_fixtures(metafunc): - for fw in ("chainer", "tensorflow"): + for fw in ("chainer", "tensorflow", "xgboost"): config = image_uris.config_for_framework(fw) if "scope" in config: _parametrize_framework_version_fixtures(metafunc, fw, config) diff --git a/tests/unit/sagemaker/image_uris/expected_uris.py b/tests/unit/sagemaker/image_uris/expected_uris.py index d0de1f7b45..55347ba6e9 100644 --- a/tests/unit/sagemaker/image_uris/expected_uris.py +++ b/tests/unit/sagemaker/image_uris/expected_uris.py @@ -31,6 +31,6 @@ def framework_uri(repo, fw_version, account, py_version=None, processor="cpu", r return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) -def algo_uri(algo, account, region): +def algo_uri(algo, account, region, version=1): domain = ALTERNATE_DOMAINS.get(region, DOMAIN) - return IMAGE_URI_FORMAT.format(account, region, domain, algo, 1) + return IMAGE_URI_FORMAT.format(account, region, domain, algo, version) diff --git a/tests/unit/sagemaker/image_uris/test_algos.py b/tests/unit/sagemaker/image_uris/test_algos.py index 0d59dc0490..a38551f27a 100644 --- a/tests/unit/sagemaker/image_uris/test_algos.py +++ b/tests/unit/sagemaker/image_uris/test_algos.py @@ -12,10 +12,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import boto3 - from sagemaker import image_uris -from tests.unit.sagemaker.image_uris import expected_uris +from tests.unit.sagemaker.image_uris import expected_uris, regions ALGO_REGIONS_AND_ACCOUNTS = ( { @@ -60,13 +58,6 @@ IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:1" -def _regions(): - boto_session = boto3.Session() - for partition in boto_session.get_available_partitions(): - for region in boto_session.get_available_regions("sagemaker", partition_name=partition): - yield region - - def _accounts_for_algo(algo): for algo_account_dict in ALGO_REGIONS_AND_ACCOUNTS: if algo in algo_account_dict["algorithms"]: @@ -79,7 +70,7 @@ def test_factorization_machines(): algo = "factorization-machines" accounts = _accounts_for_algo(algo) - for region in _regions(): + for region in regions.regions(): for scope in ("training", "inference"): uri = image_uris.retrieve(algo, region, image_scope=scope) assert expected_uris.algo_uri(algo, accounts[region], region) == uri diff --git a/tests/unit/sagemaker/image_uris/test_retrieve.py b/tests/unit/sagemaker/image_uris/test_retrieve.py index 3b405ae7b8..c8f10249ac 100644 --- a/tests/unit/sagemaker/image_uris/test_retrieve.py +++ b/tests/unit/sagemaker/image_uris/test_retrieve.py @@ -374,6 +374,34 @@ def test_retrieve_processor_type(config_for_framework): assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-gpu-py3" == uri +@patch("sagemaker.image_uris.config_for_framework") +def test_retrieve_processor_type_from_version_specific_processor_config(config_for_framework): + config = copy.deepcopy(BASE_CONFIG) + del config["processors"] + config["versions"]["1.0.0"]["processors"] = ["cpu"] + config_for_framework.return_value = config + + uri = image_uris.retrieve( + framework="useless-string", + version="1.0.0", + py_version="py3", + instance_type="ml.c4.xlarge", + region="us-west-2", + image_scope="training", + ) + assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri + + uri = image_uris.retrieve( + framework="useless-string", + version="1.1.0", + py_version="py3", + instance_type="ml.c4.xlarge", + region="us-west-2", + image_scope="training", + ) + assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.1.0-py3" == uri + + @patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG) def test_retrieve_unsupported_processor_type(config_for_framework): with pytest.raises(ValueError) as e: @@ -388,6 +416,17 @@ def test_retrieve_unsupported_processor_type(config_for_framework): assert "Invalid SageMaker instance type: not-an-instance-type." in str(e.value) + with pytest.raises(ValueError) as e: + image_uris.retrieve( + framework="useless-string", + version="1.0.0", + py_version="py3", + region="us-west-2", + image_scope="training", + ) + + assert "Empty SageMaker instance type." in str(e.value) + config = copy.deepcopy(BASE_CONFIG) config["processors"] = ["cpu"] config_for_framework.return_value = config diff --git a/tests/unit/sagemaker/image_uris/test_xgboost.py b/tests/unit/sagemaker/image_uris/test_xgboost.py new file mode 100644 index 0000000000..8d1e87d324 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/test_xgboost.py @@ -0,0 +1,105 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +from sagemaker import image_uris +from tests.unit.sagemaker.image_uris import expected_uris, regions + +ALGO_REGISTRIES = { + "ap-east-1": "286214385809", + "ap-northeast-1": "501404015308", + "ap-northeast-2": "306986355934", + "ap-south-1": "991648021394", + "ap-southeast-1": "475088953585", + "ap-southeast-2": "544295431143", + "ca-central-1": "469771592824", + "cn-north-1": "390948362332", + "cn-northwest-1": "387376663083", + "eu-central-1": "813361260812", + "eu-north-1": "669576153137", + "eu-west-1": "685385470294", + "eu-west-2": "644912444149", + "eu-west-3": "749696950732", + "me-south-1": "249704162688", + "sa-east-1": "855470959533", + "us-east-1": "811284229777", + "us-east-2": "825641698319", + "us-gov-west-1": "226302683700", + "us-iso-east-1": "490574956308", + "us-west-1": "632365934929", + "us-west-2": "433757028032", +} +ALGO_VERSIONS = ("1", "latest") + +FRAMEWORK_REGISTRIES = { + "ap-east-1": "651117190479", + "ap-northeast-1": "354813040037", + "ap-northeast-2": "366743142698", + "ap-south-1": "720646828776", + "ap-southeast-1": "121021644041", + "ap-southeast-2": "783357654285", + "ca-central-1": "341280168497", + "cn-north-1": "450853457545", + "cn-northwest-1": "451049120500", + "eu-central-1": "492215442770", + "eu-north-1": "662702820516", + "eu-west-1": "141502667606", + "eu-west-2": "764974769150", + "eu-west-3": "659782779980", + "me-south-1": "801668240914", + "sa-east-1": "737474898029", + "us-east-1": "683313688378", + "us-east-2": "257758044811", + "us-gov-west-1": "414596584902", + "us-iso-east-1": "833128469047", + "us-west-1": "746614075791", + "us-west-2": "246618743249", +} + + +def test_xgboost_framework(xgboost_framework_version): + for region in regions.regions(): + for scope in ("training", "inference"): + uri = image_uris.retrieve( + framework="xgboost", + region=region, + version=xgboost_framework_version, + py_version="py3", + instance_type="ml.c4.xlarge", + image_scope=scope, + ) + + expected = expected_uris.framework_uri( + "sagemaker-xgboost", + xgboost_framework_version, + FRAMEWORK_REGISTRIES[region], + py_version="py3", + region=region, + ) + assert expected == uri + + +@pytest.mark.parametrize("xgboost_algo_version", ("1", "latest")) +def test_xgboost_algo(xgboost_algo_version): + for region in regions.regions(): + for scope in ("training", "inference"): + uri = image_uris.retrieve( + framework="xgboost", region=region, version=xgboost_algo_version, image_scope=scope, + ) + + expected = expected_uris.algo_uri( + "xgboost", ALGO_REGISTRIES[region], region, version=xgboost_algo_version + ) + assert expected == uri diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 910c895ea6..452ccbb262 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -149,7 +149,7 @@ def _create_train_job(version, instance_count=1): } -def test_train_image(sagemaker_session, xgboost_version): +def test_train_image(sagemaker_session, xgboost_framework_version): container_log_level = '"logging.INFO"' source_dir = "s3://mybucket/source" xgboost = XGBoost( @@ -158,7 +158,7 @@ def test_train_image(sagemaker_session, xgboost_version): sagemaker_session=sagemaker_session, instance_type=INSTANCE_TYPE, instance_count=1, - framework_version=xgboost_version, + framework_version=xgboost_framework_version, container_log_level=container_log_level, py_version=PYTHON_VERSION, base_job_name="job", @@ -168,11 +168,13 @@ def test_train_image(sagemaker_session, xgboost_version): train_image = xgboost.train_image() assert ( train_image - == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-1-cpu-py3" + == "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:{}-cpu-py3".format( + xgboost_framework_version + ) ) -def test_create_model(sagemaker_session, xgboost_full_version): +def test_create_model(sagemaker_session, xgboost_framework_version): source_dir = "s3://mybucket/source" xgboost_model = XGBoostModel( @@ -180,15 +182,15 @@ def test_create_model(sagemaker_session, xgboost_full_version): role=ROLE, sagemaker_session=sagemaker_session, entry_point=SCRIPT_PATH, - framework_version=xgboost_full_version, + framework_version=xgboost_framework_version, ) - default_image_uri = _get_full_cpu_image_uri(xgboost_full_version) + default_image_uri = _get_full_cpu_image_uri(xgboost_framework_version) model_values = xgboost_model.prepare_container_def(CPU) assert model_values["Image"] == default_image_uri @patch("sagemaker.estimator.name_from_base") -def test_create_model_from_estimator(name_from_base, sagemaker_session, xgboost_version): +def test_create_model_from_estimator(name_from_base, sagemaker_session, xgboost_framework_version): container_log_level = '"logging.INFO"' source_dir = "s3://mybucket/source" base_job_name = "job" @@ -199,7 +201,7 @@ def test_create_model_from_estimator(name_from_base, sagemaker_session, xgboost_ sagemaker_session=sagemaker_session, instance_type=INSTANCE_TYPE, instance_count=1, - framework_version=xgboost_version, + framework_version=xgboost_framework_version, container_log_level=container_log_level, py_version=PYTHON_VERSION, base_job_name=base_job_name, @@ -213,7 +215,7 @@ def test_create_model_from_estimator(name_from_base, sagemaker_session, xgboost_ model = xgboost.create_model() assert model.sagemaker_session == sagemaker_session - assert model.framework_version == xgboost_version + assert model.framework_version == xgboost_framework_version assert model.py_version == xgboost.py_version assert model.entry_point == SCRIPT_PATH assert model.role == ROLE @@ -225,14 +227,14 @@ def test_create_model_from_estimator(name_from_base, sagemaker_session, xgboost_ name_from_base.assert_called_with(base_job_name) -def test_create_model_with_optional_params(sagemaker_session, xgboost_full_version): +def test_create_model_with_optional_params(sagemaker_session, xgboost_framework_version): container_log_level = '"logging.INFO"' source_dir = "s3://mybucket/source" enable_cloudwatch_metrics = "true" xgboost = XGBoost( entry_point=SCRIPT_PATH, role=ROLE, - framework_version=xgboost_full_version, + framework_version=xgboost_framework_version, sagemaker_session=sagemaker_session, instance_type=INSTANCE_TYPE, instance_count=1, @@ -273,14 +275,14 @@ def test_create_model_with_optional_params(sagemaker_session, xgboost_full_versi assert model.name == model_name -def test_create_model_with_custom_image(sagemaker_session, xgboost_full_version): +def test_create_model_with_custom_image(sagemaker_session, xgboost_framework_version): container_log_level = '"logging.INFO"' source_dir = "s3://mybucket/source" custom_image = "ubuntu:latest" xgboost = XGBoost( entry_point=SCRIPT_PATH, role=ROLE, - framework_version=xgboost_full_version, + framework_version=xgboost_framework_version, sagemaker_session=sagemaker_session, instance_type=INSTANCE_TYPE, instance_count=1, @@ -298,7 +300,7 @@ def test_create_model_with_custom_image(sagemaker_session, xgboost_full_version) @patch("time.strftime", return_value=TIMESTAMP) -def test_xgboost(strftime, sagemaker_session, xgboost_version): +def test_xgboost(strftime, sagemaker_session, xgboost_framework_version): xgboost = XGBoost( entry_point=SCRIPT_PATH, role=ROLE, @@ -306,7 +308,7 @@ def test_xgboost(strftime, sagemaker_session, xgboost_version): instance_type=INSTANCE_TYPE, instance_count=1, py_version=PYTHON_VERSION, - framework_version=xgboost_version, + framework_version=xgboost_framework_version, ) inputs = "s3://mybucket/train" @@ -318,7 +320,7 @@ def test_xgboost(strftime, sagemaker_session, xgboost_version): boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] assert boto_call_names == ["resource"] - expected_train_args = _create_train_job(xgboost_version) + expected_train_args = _create_train_job(xgboost_framework_version) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs expected_train_args["experiment_config"] = EXPERIMENT_CONFIG @@ -338,7 +340,7 @@ def test_xgboost(strftime, sagemaker_session, xgboost_version): "SAGEMAKER_REGION": "us-west-2", "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", }, - "Image": expected_image_base.format(xgboost_version, PYTHON_VERSION), + "Image": expected_image_base.format(xgboost_framework_version, PYTHON_VERSION), "ModelDataUrl": "s3://m/m.tar.gz", } == model.prepare_container_def(CPU) @@ -348,7 +350,7 @@ def test_xgboost(strftime, sagemaker_session, xgboost_version): @patch("time.strftime", return_value=TIMESTAMP) -def test_distributed_training(strftime, sagemaker_session, xgboost_version): +def test_distributed_training(strftime, sagemaker_session, xgboost_framework_version): xgboost = XGBoost( entry_point=SCRIPT_PATH, role=ROLE, @@ -356,7 +358,7 @@ def test_distributed_training(strftime, sagemaker_session, xgboost_version): instance_count=DIST_INSTANCE_COUNT, instance_type=INSTANCE_TYPE, py_version=PYTHON_VERSION, - framework_version=xgboost_version, + framework_version=xgboost_framework_version, ) inputs = "s3://mybucket/train" @@ -368,7 +370,7 @@ def test_distributed_training(strftime, sagemaker_session, xgboost_version): boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] assert boto_call_names == ["resource"] - expected_train_args = _create_train_job(xgboost_version, DIST_INSTANCE_COUNT) + expected_train_args = _create_train_job(xgboost_framework_version, DIST_INSTANCE_COUNT) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs actual_train_args = sagemaker_session.method_calls[0][2] @@ -387,7 +389,7 @@ def test_distributed_training(strftime, sagemaker_session, xgboost_version): "SAGEMAKER_REGION": "us-west-2", "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", }, - "Image": expected_image_base.format(xgboost_version, PYTHON_VERSION), + "Image": expected_image_base.format(xgboost_framework_version, PYTHON_VERSION), "ModelDataUrl": "s3://m/m.tar.gz", } == model.prepare_container_def(CPU) @@ -396,11 +398,11 @@ def test_distributed_training(strftime, sagemaker_session, xgboost_version): assert isinstance(predictor, XGBoostPredictor) -def test_model(sagemaker_session, xgboost_full_version): +def test_model(sagemaker_session, xgboost_framework_version): model = XGBoostModel( "s3://some/data.tar.gz", role=ROLE, - framework_version=xgboost_full_version, + framework_version=xgboost_framework_version, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session, ) @@ -408,34 +410,40 @@ def test_model(sagemaker_session, xgboost_full_version): assert isinstance(predictor, XGBoostPredictor) -def test_train_image_default(sagemaker_session, xgboost_full_version): +def test_train_image_default(sagemaker_session, xgboost_framework_version): xgboost = XGBoost( entry_point=SCRIPT_PATH, role=ROLE, - framework_version=xgboost_full_version, + framework_version=xgboost_framework_version, sagemaker_session=sagemaker_session, instance_type=INSTANCE_TYPE, instance_count=1, py_version=PYTHON_VERSION, ) - assert _get_full_cpu_image_uri(xgboost_full_version) in xgboost.train_image() + assert _get_full_cpu_image_uri(xgboost_framework_version) in xgboost.train_image() -def test_train_image_cpu_instances(sagemaker_session, xgboost_version): - xgboost = _xgboost_estimator(sagemaker_session, xgboost_version, instance_type="ml.c2.2xlarge") - assert xgboost.train_image() == _get_full_cpu_image_uri(xgboost_version) +def test_train_image_cpu_instances(sagemaker_session, xgboost_framework_version): + xgboost = _xgboost_estimator( + sagemaker_session, xgboost_framework_version, instance_type="ml.c2.2xlarge" + ) + assert xgboost.train_image() == _get_full_cpu_image_uri(xgboost_framework_version) - xgboost = _xgboost_estimator(sagemaker_session, xgboost_version, instance_type="ml.c4.2xlarge") - assert xgboost.train_image() == _get_full_cpu_image_uri(xgboost_version) + xgboost = _xgboost_estimator( + sagemaker_session, xgboost_framework_version, instance_type="ml.c4.2xlarge" + ) + assert xgboost.train_image() == _get_full_cpu_image_uri(xgboost_framework_version) - xgboost = _xgboost_estimator(sagemaker_session, xgboost_version, instance_type="ml.m16") - assert xgboost.train_image() == _get_full_cpu_image_uri(xgboost_version) + xgboost = _xgboost_estimator( + sagemaker_session, xgboost_framework_version, instance_type="ml.m16" + ) + assert xgboost.train_image() == _get_full_cpu_image_uri(xgboost_framework_version) -def test_attach(sagemaker_session, xgboost_version): +def test_attach(sagemaker_session, xgboost_framework_version): training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:{}-cpu-{}".format( - xgboost_version, PYTHON_VERSION + xgboost_framework_version, PYTHON_VERSION ) returned_job_description = { "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, @@ -470,7 +478,7 @@ def test_attach(sagemaker_session, xgboost_version): assert estimator._current_job_name == "neo" assert estimator.latest_training_job.job_name == "neo" assert estimator.py_version == PYTHON_VERSION - assert estimator.framework_version == xgboost_version + assert estimator.framework_version == xgboost_framework_version assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.instance_count == 1 assert estimator.max_run == 24 * 60 * 60 @@ -556,12 +564,12 @@ def test_attach_custom_image(sagemaker_session): assert "expected string" in str(error) -def test_py2_xgboost_attribute_error(sagemaker_session, xgboost_full_version): +def test_py2_xgboost_attribute_error(sagemaker_session, xgboost_framework_version): with pytest.raises(AttributeError) as error1: XGBoost( entry_point=SCRIPT_PATH, role=ROLE, - framework_version=xgboost_full_version, + framework_version=xgboost_framework_version, sagemaker_session=sagemaker_session, instance_type=INSTANCE_TYPE, instance_count=1, @@ -574,7 +582,7 @@ def test_py2_xgboost_attribute_error(sagemaker_session, xgboost_full_version): role=ROLE, sagemaker_session=sagemaker_session, entry_point=SCRIPT_PATH, - framework_version=xgboost_full_version, + framework_version=xgboost_framework_version, py_version="py2", ) From 8d6fbb27ea97f459e83e9f0ef32c2e143231b2a6 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Wed, 15 Jul 2020 14:49:16 -0700 Subject: [PATCH 2/2] add new file --- tests/unit/sagemaker/image_uris/regions.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 tests/unit/sagemaker/image_uris/regions.py diff --git a/tests/unit/sagemaker/image_uris/regions.py b/tests/unit/sagemaker/image_uris/regions.py new file mode 100644 index 0000000000..0e59e98c24 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/regions.py @@ -0,0 +1,22 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import boto3 + + +def regions(): + boto_session = boto3.Session() + for partition in boto_session.get_available_partitions(): + for region in boto_session.get_available_regions("sagemaker", partition_name=partition): + yield region