diff --git a/tests/unit/sagemaker/image_uris/test_algos.py b/tests/unit/sagemaker/image_uris/test_algos.py index 9ece7376f3..fbc81f7829 100644 --- a/tests/unit/sagemaker/image_uris/test_algos.py +++ b/tests/unit/sagemaker/image_uris/test_algos.py @@ -15,7 +15,7 @@ import pytest from sagemaker import image_uris -from tests.unit.sagemaker.image_uris import expected_uris, regions +from tests.unit.sagemaker.image_uris import expected_uris ALGO_NAMES = ( "blazingtext", @@ -33,6 +33,7 @@ "randomcutforest", "semantic-segmentation", "seq2seq", + "lda", ) ALGO_REGIONS_AND_ACCOUNTS = ( { @@ -176,21 +177,6 @@ def _accounts_for_algo(algo): @pytest.mark.parametrize("algo", ALGO_NAMES) def test_algo_uris(algo): accounts = _accounts_for_algo(algo) - - for region in regions.regions(): + for region in accounts: uri = image_uris.retrieve(algo, region) assert expected_uris.algo_uri(algo, accounts[region], region) == uri - - -def test_lda(): - algo = "lda" - accounts = _accounts_for_algo(algo) - - for region in regions.regions(): - if region in accounts: - uri = image_uris.retrieve(algo, region) - assert expected_uris.algo_uri(algo, accounts[region], region) == uri - else: - with pytest.raises(ValueError) as e: - image_uris.retrieve(algo, region) - assert "Unsupported region: {}.".format(region) in str(e.value) diff --git a/tests/unit/sagemaker/image_uris/test_data_wrangler.py b/tests/unit/sagemaker/image_uris/test_data_wrangler.py index d56c98ae21..79f2e83e71 100644 --- a/tests/unit/sagemaker/image_uris/test_data_wrangler.py +++ b/tests/unit/sagemaker/image_uris/test_data_wrangler.py @@ -13,7 +13,7 @@ from __future__ import absolute_import from sagemaker import image_uris -from tests.unit.sagemaker.image_uris import expected_uris, regions +from tests.unit.sagemaker.image_uris import expected_uris DATA_WRANGLER_ACCOUNTS = { "af-south-1": "143210264188", @@ -42,14 +42,12 @@ def test_data_wrangler_ecr_uri(): - for region in regions.regions(): - if region in DATA_WRANGLER_ACCOUNTS.keys(): - actual_uri = image_uris.retrieve("data-wrangler", region=region) - - expected_uri = expected_uris.algo_uri( - "sagemaker-data-wrangler-container", - DATA_WRANGLER_ACCOUNTS[region], - region, - version="1.x", - ) - assert expected_uri == actual_uri + for region in DATA_WRANGLER_ACCOUNTS.keys(): + actual_uri = image_uris.retrieve("data-wrangler", region=region) + expected_uri = expected_uris.algo_uri( + "sagemaker-data-wrangler-container", + DATA_WRANGLER_ACCOUNTS[region], + region, + version="1.x", + ) + assert expected_uri == actual_uri diff --git a/tests/unit/sagemaker/image_uris/test_debugger.py b/tests/unit/sagemaker/image_uris/test_debugger.py index eafdf64f3b..3acf6251aa 100644 --- a/tests/unit/sagemaker/image_uris/test_debugger.py +++ b/tests/unit/sagemaker/image_uris/test_debugger.py @@ -13,13 +13,14 @@ from __future__ import absolute_import from sagemaker import image_uris -from tests.unit.sagemaker.image_uris import expected_uris, regions +from tests.unit.sagemaker.image_uris import expected_uris ACCOUNTS = { "af-south-1": "314341159256", "ap-east-1": "199566480951", "ap-northeast-1": "430734990657", "ap-northeast-2": "578805364391", + "ap-northeast-3": "479947661362", "ap-south-1": "904829902805", "ap-southeast-1": "972752614525", "ap-southeast-2": "184798709955", @@ -43,11 +44,9 @@ def test_debugger(): - for region in regions.regions(): - if region in ACCOUNTS.keys(): - uri = image_uris.retrieve("debugger", region=region) - - expected = expected_uris.algo_uri( - "sagemaker-debugger-rules", ACCOUNTS[region], region, version="latest" - ) - assert expected == uri + for region in ACCOUNTS.keys(): + uri = image_uris.retrieve("debugger", region=region) + expected = expected_uris.algo_uri( + "sagemaker-debugger-rules", ACCOUNTS[region], region, version="latest" + ) + assert expected == uri diff --git a/tests/unit/sagemaker/image_uris/test_neo.py b/tests/unit/sagemaker/image_uris/test_neo.py index 8f16100d32..d630862381 100644 --- a/tests/unit/sagemaker/image_uris/test_neo.py +++ b/tests/unit/sagemaker/image_uris/test_neo.py @@ -15,7 +15,7 @@ import pytest from sagemaker import image_uris -from tests.unit.sagemaker.image_uris import expected_uris, regions +from tests.unit.sagemaker.image_uris import expected_uris NEO_ALGOS = ("image-classification-neo", "xgboost-neo") @@ -24,6 +24,7 @@ "ap-east-1": "110948597952", "ap-northeast-1": "941853720454", "ap-northeast-2": "151534178276", + "ap-northeast-3": "925152966179", "ap-south-1": "763008648453", "ap-southeast-1": "324986816169", "ap-southeast-2": "355873309152", @@ -50,33 +51,21 @@ @pytest.mark.parametrize("algo", NEO_ALGOS) def test_algo_uris(algo): - for region in regions.regions(): - if region in ACCOUNTS: - uri = image_uris.retrieve(algo, region) - expected = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest") - assert expected == uri - else: - with pytest.raises(ValueError) as e: - image_uris.retrieve(algo, region) - assert "Unsupported region: {}.".format(region) in str(e.value) + for region in ACCOUNTS.keys(): + uri = image_uris.retrieve(algo, region) + expected = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest") + assert expected == uri def _test_neo_framework_uris(framework, version): framework_in_config = f"neo-{framework}" framework_in_uri = f"inference-{framework}" - for region in regions.regions(): - if region in ACCOUNTS: - uri = image_uris.retrieve( - framework_in_config, region, instance_type="ml_c5", version=version - ) - assert _expected_framework_uri(framework_in_uri, version, region=region) == uri - else: - with pytest.raises(ValueError) as e: - image_uris.retrieve( - framework_in_config, region, instance_type="ml_c5", version=version - ) - assert "Unsupported region: {}.".format(region) in str(e.value) + for region in ACCOUNTS.keys(): + uri = image_uris.retrieve( + framework_in_config, region, instance_type="ml_c5", version=version + ) + assert _expected_framework_uri(framework_in_uri, version, region=region) == uri uri = image_uris.retrieve( framework_in_config, "us-west-2", instance_type="ml_p2", version=version @@ -97,24 +86,14 @@ def test_neo_pytorch(neo_pytorch_version): def _test_inferentia_framework_uris(framework, version): - for region in regions.regions(): - if region in INFERENTIA_REGIONS: - uri = image_uris.retrieve( - "inferentia-{}".format(framework), region, instance_type="ml_inf1", version=version - ) - expected = _expected_framework_uri( - "neo-{}".format(framework), version, region=region, processor="inf" - ) - assert expected == uri - else: - with pytest.raises(ValueError) as e: - image_uris.retrieve( - "inferentia-{}".format(framework), - region, - instance_type="ml_inf", - version=version, - ) - assert "Unsupported region: {}.".format(region) in str(e.value) + for region in INFERENTIA_REGIONS: + uri = image_uris.retrieve( + "inferentia-{}".format(framework), region, instance_type="ml_inf1", version=version + ) + expected = _expected_framework_uri( + "neo-{}".format(framework), version, region=region, processor="inf" + ) + assert expected == uri def test_inferentia_mxnet(inferentia_mxnet_version): diff --git a/tests/unit/sagemaker/image_uris/test_sklearn.py b/tests/unit/sagemaker/image_uris/test_sklearn.py index f31ba4c4a2..58b668f3a2 100644 --- a/tests/unit/sagemaker/image_uris/test_sklearn.py +++ b/tests/unit/sagemaker/image_uris/test_sklearn.py @@ -15,7 +15,7 @@ import pytest from sagemaker import image_uris -from tests.unit.sagemaker.image_uris import expected_uris, regions +from tests.unit.sagemaker.image_uris import expected_uris ACCOUNTS = { "af-south-1": "510948584623", @@ -47,7 +47,7 @@ def test_valid_uris(sklearn_version): - for region in regions.regions(): + for region in ACCOUNTS.keys(): uri = image_uris.retrieve( "sklearn", region=region, diff --git a/tests/unit/sagemaker/image_uris/test_sparkml.py b/tests/unit/sagemaker/image_uris/test_sparkml.py index c5260178e7..3addf34d33 100644 --- a/tests/unit/sagemaker/image_uris/test_sparkml.py +++ b/tests/unit/sagemaker/image_uris/test_sparkml.py @@ -15,7 +15,7 @@ import pytest from sagemaker import image_uris -from tests.unit.sagemaker.image_uris import expected_uris, regions +from tests.unit.sagemaker.image_uris import expected_uris ACCOUNTS = { "af-south-1": "510948584623", @@ -48,7 +48,7 @@ @pytest.mark.parametrize("version", VERSIONS) def test_sparkml(version): - for region in regions.regions(): + for region in ACCOUNTS.keys(): uri = image_uris.retrieve("sparkml-serving", region=region, version=version) expected = expected_uris.algo_uri( diff --git a/tests/unit/sagemaker/image_uris/test_xgboost.py b/tests/unit/sagemaker/image_uris/test_xgboost.py index 943f0ae259..6673de3bbb 100644 --- a/tests/unit/sagemaker/image_uris/test_xgboost.py +++ b/tests/unit/sagemaker/image_uris/test_xgboost.py @@ -11,11 +11,9 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import - import pytest - from sagemaker import image_uris -from tests.unit.sagemaker.image_uris import expected_uris, regions +from tests.unit.sagemaker.image_uris import expected_uris ALGO_REGISTRIES = { "af-south-1": "455444449433", @@ -53,6 +51,7 @@ "ap-east-1": "651117190479", "ap-northeast-1": "354813040037", "ap-northeast-2": "366743142698", + "ap-northeast-3": "867004704886", "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", @@ -78,7 +77,7 @@ @pytest.mark.parametrize("xgboost_framework_version", XGBOOST_FRAMEWORK_CPU_GPU_VERSIONS) def test_xgboost_framework(xgboost_framework_version): - for region in regions.regions(): + for region in FRAMEWORK_REGISTRIES.keys(): uri = image_uris.retrieve( framework="xgboost", region=region, @@ -98,7 +97,7 @@ def test_xgboost_framework(xgboost_framework_version): @pytest.mark.parametrize("xgboost_framework_version", XGBOOST_FRAMEWORK_CPU_ONLY_VERSIONS) def test_xgboost_framework_cpu_only(xgboost_framework_version): - for region in regions.regions(): + for region in FRAMEWORK_REGISTRIES.keys(): uri = image_uris.retrieve( framework="xgboost", region=region, @@ -118,7 +117,7 @@ def test_xgboost_framework_cpu_only(xgboost_framework_version): @pytest.mark.parametrize("xgboost_algo_version", ALGO_VERSIONS) def test_xgboost_algo(xgboost_algo_version): - for region in regions.regions(): + for region in ALGO_REGISTRIES.keys(): uri = image_uris.retrieve(framework="xgboost", region=region, version=xgboost_algo_version) expected = expected_uris.algo_uri(