Skip to content

fix: unit tests for KIX and remove regional calls to boto #2640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions tests/unit/sagemaker/image_uris/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris, regions
from tests.unit.sagemaker.image_uris import expected_uris

ALGO_NAMES = (
"blazingtext",
Expand All @@ -33,6 +33,7 @@
"randomcutforest",
"semantic-segmentation",
"seq2seq",
"lda",
)
ALGO_REGIONS_AND_ACCOUNTS = (
{
Expand Down Expand Up @@ -176,21 +177,6 @@ def _accounts_for_algo(algo):
@pytest.mark.parametrize("algo", ALGO_NAMES)
def test_algo_uris(algo):
accounts = _accounts_for_algo(algo)

for region in regions.regions():
for region in accounts:
uri = image_uris.retrieve(algo, region)
assert expected_uris.algo_uri(algo, accounts[region], region) == uri


def test_lda():
algo = "lda"
accounts = _accounts_for_algo(algo)

for region in regions.regions():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactoring this test since it is a duplicacy and the LDA param can be instead pulled into the config and tested from above

if region in accounts:
uri = image_uris.retrieve(algo, region)
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
else:
with pytest.raises(ValueError) as e:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing these if/else blocks because 1. The if is really filtering the regions and can be combined with the loop above, and since we are removing use of boto regions anyways this makes no sense.

The ValueError raised in else is not something that should be done in a test, especially if we are checking only supported regions, and seems redundant.

image_uris.retrieve(algo, region)
assert "Unsupported region: {}.".format(region) in str(e.value)
22 changes: 10 additions & 12 deletions tests/unit/sagemaker/image_uris/test_data_wrangler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from __future__ import absolute_import

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris, regions
from tests.unit.sagemaker.image_uris import expected_uris

DATA_WRANGLER_ACCOUNTS = {
"af-south-1": "143210264188",
Expand Down Expand Up @@ -42,14 +42,12 @@


def test_data_wrangler_ecr_uri():
for region in regions.regions():
if region in DATA_WRANGLER_ACCOUNTS.keys():
actual_uri = image_uris.retrieve("data-wrangler", region=region)

expected_uri = expected_uris.algo_uri(
"sagemaker-data-wrangler-container",
DATA_WRANGLER_ACCOUNTS[region],
region,
version="1.x",
)
assert expected_uri == actual_uri
for region in DATA_WRANGLER_ACCOUNTS.keys():
actual_uri = image_uris.retrieve("data-wrangler", region=region)
expected_uri = expected_uris.algo_uri(
"sagemaker-data-wrangler-container",
DATA_WRANGLER_ACCOUNTS[region],
region,
version="1.x",
)
assert expected_uri == actual_uri
17 changes: 8 additions & 9 deletions tests/unit/sagemaker/image_uris/test_debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
from __future__ import absolute_import

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris, regions
from tests.unit.sagemaker.image_uris import expected_uris

ACCOUNTS = {
"af-south-1": "314341159256",
"ap-east-1": "199566480951",
"ap-northeast-1": "430734990657",
"ap-northeast-2": "578805364391",
"ap-northeast-3": "479947661362",
"ap-south-1": "904829902805",
"ap-southeast-1": "972752614525",
"ap-southeast-2": "184798709955",
Expand All @@ -43,11 +44,9 @@


def test_debugger():
for region in regions.regions():
if region in ACCOUNTS.keys():
uri = image_uris.retrieve("debugger", region=region)

expected = expected_uris.algo_uri(
"sagemaker-debugger-rules", ACCOUNTS[region], region, version="latest"
)
assert expected == uri
for region in ACCOUNTS.keys():
uri = image_uris.retrieve("debugger", region=region)
expected = expected_uris.algo_uri(
"sagemaker-debugger-rules", ACCOUNTS[region], region, version="latest"
)
assert expected == uri
59 changes: 19 additions & 40 deletions tests/unit/sagemaker/image_uris/test_neo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris, regions
from tests.unit.sagemaker.image_uris import expected_uris

NEO_ALGOS = ("image-classification-neo", "xgboost-neo")

Expand All @@ -24,6 +24,7 @@
"ap-east-1": "110948597952",
"ap-northeast-1": "941853720454",
"ap-northeast-2": "151534178276",
"ap-northeast-3": "925152966179",
"ap-south-1": "763008648453",
"ap-southeast-1": "324986816169",
"ap-southeast-2": "355873309152",
Expand All @@ -50,33 +51,21 @@

@pytest.mark.parametrize("algo", NEO_ALGOS)
def test_algo_uris(algo):
for region in regions.regions():
if region in ACCOUNTS:
uri = image_uris.retrieve(algo, region)
expected = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
assert expected == uri
else:
with pytest.raises(ValueError) as e:
image_uris.retrieve(algo, region)
assert "Unsupported region: {}.".format(region) in str(e.value)
for region in ACCOUNTS.keys():
uri = image_uris.retrieve(algo, region)
expected = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
assert expected == uri


def _test_neo_framework_uris(framework, version):
framework_in_config = f"neo-{framework}"
framework_in_uri = f"inference-{framework}"

for region in regions.regions():
if region in ACCOUNTS:
uri = image_uris.retrieve(
framework_in_config, region, instance_type="ml_c5", version=version
)
assert _expected_framework_uri(framework_in_uri, version, region=region) == uri
else:
with pytest.raises(ValueError) as e:
image_uris.retrieve(
framework_in_config, region, instance_type="ml_c5", version=version
)
assert "Unsupported region: {}.".format(region) in str(e.value)
for region in ACCOUNTS.keys():
uri = image_uris.retrieve(
framework_in_config, region, instance_type="ml_c5", version=version
)
assert _expected_framework_uri(framework_in_uri, version, region=region) == uri

uri = image_uris.retrieve(
framework_in_config, "us-west-2", instance_type="ml_p2", version=version
Expand All @@ -97,24 +86,14 @@ def test_neo_pytorch(neo_pytorch_version):


def _test_inferentia_framework_uris(framework, version):
for region in regions.regions():
if region in INFERENTIA_REGIONS:
uri = image_uris.retrieve(
"inferentia-{}".format(framework), region, instance_type="ml_inf1", version=version
)
expected = _expected_framework_uri(
"neo-{}".format(framework), version, region=region, processor="inf"
)
assert expected == uri
else:
with pytest.raises(ValueError) as e:
image_uris.retrieve(
"inferentia-{}".format(framework),
region,
instance_type="ml_inf",
version=version,
)
assert "Unsupported region: {}.".format(region) in str(e.value)
for region in INFERENTIA_REGIONS:
uri = image_uris.retrieve(
"inferentia-{}".format(framework), region, instance_type="ml_inf1", version=version
)
expected = _expected_framework_uri(
"neo-{}".format(framework), version, region=region, processor="inf"
)
assert expected == uri


def test_inferentia_mxnet(inferentia_mxnet_version):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/image_uris/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris, regions
from tests.unit.sagemaker.image_uris import expected_uris

ACCOUNTS = {
"af-south-1": "510948584623",
Expand Down Expand Up @@ -47,7 +47,7 @@


def test_valid_uris(sklearn_version):
for region in regions.regions():
for region in ACCOUNTS.keys():
uri = image_uris.retrieve(
"sklearn",
region=region,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/image_uris/test_sparkml.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris, regions
from tests.unit.sagemaker.image_uris import expected_uris

ACCOUNTS = {
"af-south-1": "510948584623",
Expand Down Expand Up @@ -48,7 +48,7 @@

@pytest.mark.parametrize("version", VERSIONS)
def test_sparkml(version):
for region in regions.regions():
for region in ACCOUNTS.keys():
uri = image_uris.retrieve("sparkml-serving", region=region, version=version)

expected = expected_uris.algo_uri(
Expand Down
11 changes: 5 additions & 6 deletions tests/unit/sagemaker/image_uris/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris, regions
from tests.unit.sagemaker.image_uris import expected_uris

ALGO_REGISTRIES = {
"af-south-1": "455444449433",
Expand Down Expand Up @@ -53,6 +51,7 @@
"ap-east-1": "651117190479",
"ap-northeast-1": "354813040037",
"ap-northeast-2": "366743142698",
"ap-northeast-3": "867004704886",
"ap-south-1": "720646828776",
"ap-southeast-1": "121021644041",
"ap-southeast-2": "783357654285",
Expand All @@ -78,7 +77,7 @@

@pytest.mark.parametrize("xgboost_framework_version", XGBOOST_FRAMEWORK_CPU_GPU_VERSIONS)
def test_xgboost_framework(xgboost_framework_version):
for region in regions.regions():
for region in FRAMEWORK_REGISTRIES.keys():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are removing regions.region what are we testing here now? Seems to me we are testing image_uris.retrieve() to expected_uris.framework_uri() which is unnecessary.

uri = image_uris.retrieve(
framework="xgboost",
region=region,
Expand All @@ -98,7 +97,7 @@ def test_xgboost_framework(xgboost_framework_version):

@pytest.mark.parametrize("xgboost_framework_version", XGBOOST_FRAMEWORK_CPU_ONLY_VERSIONS)
def test_xgboost_framework_cpu_only(xgboost_framework_version):
for region in regions.regions():
for region in FRAMEWORK_REGISTRIES.keys():
uri = image_uris.retrieve(
framework="xgboost",
region=region,
Expand All @@ -118,7 +117,7 @@ def test_xgboost_framework_cpu_only(xgboost_framework_version):

@pytest.mark.parametrize("xgboost_algo_version", ALGO_VERSIONS)
def test_xgboost_algo(xgboost_algo_version):
for region in regions.regions():
for region in ALGO_REGISTRIES.keys():
uri = image_uris.retrieve(framework="xgboost", region=region, version=xgboost_algo_version)

expected = expected_uris.algo_uri(
Expand Down