Skip to content

Commit 985005e

Browse files
authored
fix: unit tests for KIX and remove regional calls to boto (#2640)
* remove boto region fetch call and use locals where applicable
1 parent 8447430 commit 985005e

File tree

7 files changed

+49
-88
lines changed

7 files changed

+49
-88
lines changed

tests/unit/sagemaker/image_uris/test_algos.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
from sagemaker import image_uris
18-
from tests.unit.sagemaker.image_uris import expected_uris, regions
18+
from tests.unit.sagemaker.image_uris import expected_uris
1919

2020
ALGO_NAMES = (
2121
"blazingtext",
@@ -33,6 +33,7 @@
3333
"randomcutforest",
3434
"semantic-segmentation",
3535
"seq2seq",
36+
"lda",
3637
)
3738
ALGO_REGIONS_AND_ACCOUNTS = (
3839
{
@@ -176,21 +177,6 @@ def _accounts_for_algo(algo):
176177
@pytest.mark.parametrize("algo", ALGO_NAMES)
177178
def test_algo_uris(algo):
178179
accounts = _accounts_for_algo(algo)
179-
180-
for region in regions.regions():
180+
for region in accounts:
181181
uri = image_uris.retrieve(algo, region)
182182
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
183-
184-
185-
def test_lda():
186-
algo = "lda"
187-
accounts = _accounts_for_algo(algo)
188-
189-
for region in regions.regions():
190-
if region in accounts:
191-
uri = image_uris.retrieve(algo, region)
192-
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
193-
else:
194-
with pytest.raises(ValueError) as e:
195-
image_uris.retrieve(algo, region)
196-
assert "Unsupported region: {}.".format(region) in str(e.value)

tests/unit/sagemaker/image_uris/test_data_wrangler.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
from sagemaker import image_uris
16-
from tests.unit.sagemaker.image_uris import expected_uris, regions
16+
from tests.unit.sagemaker.image_uris import expected_uris
1717

1818
DATA_WRANGLER_ACCOUNTS = {
1919
"af-south-1": "143210264188",
@@ -42,14 +42,12 @@
4242

4343

4444
def test_data_wrangler_ecr_uri():
45-
for region in regions.regions():
46-
if region in DATA_WRANGLER_ACCOUNTS.keys():
47-
actual_uri = image_uris.retrieve("data-wrangler", region=region)
48-
49-
expected_uri = expected_uris.algo_uri(
50-
"sagemaker-data-wrangler-container",
51-
DATA_WRANGLER_ACCOUNTS[region],
52-
region,
53-
version="1.x",
54-
)
55-
assert expected_uri == actual_uri
45+
for region in DATA_WRANGLER_ACCOUNTS.keys():
46+
actual_uri = image_uris.retrieve("data-wrangler", region=region)
47+
expected_uri = expected_uris.algo_uri(
48+
"sagemaker-data-wrangler-container",
49+
DATA_WRANGLER_ACCOUNTS[region],
50+
region,
51+
version="1.x",
52+
)
53+
assert expected_uri == actual_uri

tests/unit/sagemaker/image_uris/test_debugger.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
from __future__ import absolute_import
1414

1515
from sagemaker import image_uris
16-
from tests.unit.sagemaker.image_uris import expected_uris, regions
16+
from tests.unit.sagemaker.image_uris import expected_uris
1717

1818
ACCOUNTS = {
1919
"af-south-1": "314341159256",
2020
"ap-east-1": "199566480951",
2121
"ap-northeast-1": "430734990657",
2222
"ap-northeast-2": "578805364391",
23+
"ap-northeast-3": "479947661362",
2324
"ap-south-1": "904829902805",
2425
"ap-southeast-1": "972752614525",
2526
"ap-southeast-2": "184798709955",
@@ -43,11 +44,9 @@
4344

4445

4546
def test_debugger():
46-
for region in regions.regions():
47-
if region in ACCOUNTS.keys():
48-
uri = image_uris.retrieve("debugger", region=region)
49-
50-
expected = expected_uris.algo_uri(
51-
"sagemaker-debugger-rules", ACCOUNTS[region], region, version="latest"
52-
)
53-
assert expected == uri
47+
for region in ACCOUNTS.keys():
48+
uri = image_uris.retrieve("debugger", region=region)
49+
expected = expected_uris.algo_uri(
50+
"sagemaker-debugger-rules", ACCOUNTS[region], region, version="latest"
51+
)
52+
assert expected == uri

tests/unit/sagemaker/image_uris/test_neo.py

Lines changed: 19 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
from sagemaker import image_uris
18-
from tests.unit.sagemaker.image_uris import expected_uris, regions
18+
from tests.unit.sagemaker.image_uris import expected_uris
1919

2020
NEO_ALGOS = ("image-classification-neo", "xgboost-neo")
2121

@@ -24,6 +24,7 @@
2424
"ap-east-1": "110948597952",
2525
"ap-northeast-1": "941853720454",
2626
"ap-northeast-2": "151534178276",
27+
"ap-northeast-3": "925152966179",
2728
"ap-south-1": "763008648453",
2829
"ap-southeast-1": "324986816169",
2930
"ap-southeast-2": "355873309152",
@@ -50,33 +51,21 @@
5051

5152
@pytest.mark.parametrize("algo", NEO_ALGOS)
5253
def test_algo_uris(algo):
53-
for region in regions.regions():
54-
if region in ACCOUNTS:
55-
uri = image_uris.retrieve(algo, region)
56-
expected = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
57-
assert expected == uri
58-
else:
59-
with pytest.raises(ValueError) as e:
60-
image_uris.retrieve(algo, region)
61-
assert "Unsupported region: {}.".format(region) in str(e.value)
54+
for region in ACCOUNTS.keys():
55+
uri = image_uris.retrieve(algo, region)
56+
expected = expected_uris.algo_uri(algo, ACCOUNTS[region], region, version="latest")
57+
assert expected == uri
6258

6359

6460
def _test_neo_framework_uris(framework, version):
6561
framework_in_config = f"neo-{framework}"
6662
framework_in_uri = f"inference-{framework}"
6763

68-
for region in regions.regions():
69-
if region in ACCOUNTS:
70-
uri = image_uris.retrieve(
71-
framework_in_config, region, instance_type="ml_c5", version=version
72-
)
73-
assert _expected_framework_uri(framework_in_uri, version, region=region) == uri
74-
else:
75-
with pytest.raises(ValueError) as e:
76-
image_uris.retrieve(
77-
framework_in_config, region, instance_type="ml_c5", version=version
78-
)
79-
assert "Unsupported region: {}.".format(region) in str(e.value)
64+
for region in ACCOUNTS.keys():
65+
uri = image_uris.retrieve(
66+
framework_in_config, region, instance_type="ml_c5", version=version
67+
)
68+
assert _expected_framework_uri(framework_in_uri, version, region=region) == uri
8069

8170
uri = image_uris.retrieve(
8271
framework_in_config, "us-west-2", instance_type="ml_p2", version=version
@@ -97,24 +86,14 @@ def test_neo_pytorch(neo_pytorch_version):
9786

9887

9988
def _test_inferentia_framework_uris(framework, version):
100-
for region in regions.regions():
101-
if region in INFERENTIA_REGIONS:
102-
uri = image_uris.retrieve(
103-
"inferentia-{}".format(framework), region, instance_type="ml_inf1", version=version
104-
)
105-
expected = _expected_framework_uri(
106-
"neo-{}".format(framework), version, region=region, processor="inf"
107-
)
108-
assert expected == uri
109-
else:
110-
with pytest.raises(ValueError) as e:
111-
image_uris.retrieve(
112-
"inferentia-{}".format(framework),
113-
region,
114-
instance_type="ml_inf",
115-
version=version,
116-
)
117-
assert "Unsupported region: {}.".format(region) in str(e.value)
89+
for region in INFERENTIA_REGIONS:
90+
uri = image_uris.retrieve(
91+
"inferentia-{}".format(framework), region, instance_type="ml_inf1", version=version
92+
)
93+
expected = _expected_framework_uri(
94+
"neo-{}".format(framework), version, region=region, processor="inf"
95+
)
96+
assert expected == uri
11897

11998

12099
def test_inferentia_mxnet(inferentia_mxnet_version):

tests/unit/sagemaker/image_uris/test_sklearn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
from sagemaker import image_uris
18-
from tests.unit.sagemaker.image_uris import expected_uris, regions
18+
from tests.unit.sagemaker.image_uris import expected_uris
1919

2020
ACCOUNTS = {
2121
"af-south-1": "510948584623",
@@ -47,7 +47,7 @@
4747

4848

4949
def test_valid_uris(sklearn_version):
50-
for region in regions.regions():
50+
for region in ACCOUNTS.keys():
5151
uri = image_uris.retrieve(
5252
"sklearn",
5353
region=region,

tests/unit/sagemaker/image_uris/test_sparkml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
from sagemaker import image_uris
18-
from tests.unit.sagemaker.image_uris import expected_uris, regions
18+
from tests.unit.sagemaker.image_uris import expected_uris
1919

2020
ACCOUNTS = {
2121
"af-south-1": "510948584623",
@@ -48,7 +48,7 @@
4848

4949
@pytest.mark.parametrize("version", VERSIONS)
5050
def test_sparkml(version):
51-
for region in regions.regions():
51+
for region in ACCOUNTS.keys():
5252
uri = image_uris.retrieve("sparkml-serving", region=region, version=version)
5353

5454
expected = expected_uris.algo_uri(

tests/unit/sagemaker/image_uris/test_xgboost.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14-
1514
import pytest
16-
1715
from sagemaker import image_uris
18-
from tests.unit.sagemaker.image_uris import expected_uris, regions
16+
from tests.unit.sagemaker.image_uris import expected_uris
1917

2018
ALGO_REGISTRIES = {
2119
"af-south-1": "455444449433",
@@ -53,6 +51,7 @@
5351
"ap-east-1": "651117190479",
5452
"ap-northeast-1": "354813040037",
5553
"ap-northeast-2": "366743142698",
54+
"ap-northeast-3": "867004704886",
5655
"ap-south-1": "720646828776",
5756
"ap-southeast-1": "121021644041",
5857
"ap-southeast-2": "783357654285",
@@ -78,7 +77,7 @@
7877

7978
@pytest.mark.parametrize("xgboost_framework_version", XGBOOST_FRAMEWORK_CPU_GPU_VERSIONS)
8079
def test_xgboost_framework(xgboost_framework_version):
81-
for region in regions.regions():
80+
for region in FRAMEWORK_REGISTRIES.keys():
8281
uri = image_uris.retrieve(
8382
framework="xgboost",
8483
region=region,
@@ -98,7 +97,7 @@ def test_xgboost_framework(xgboost_framework_version):
9897

9998
@pytest.mark.parametrize("xgboost_framework_version", XGBOOST_FRAMEWORK_CPU_ONLY_VERSIONS)
10099
def test_xgboost_framework_cpu_only(xgboost_framework_version):
101-
for region in regions.regions():
100+
for region in FRAMEWORK_REGISTRIES.keys():
102101
uri = image_uris.retrieve(
103102
framework="xgboost",
104103
region=region,
@@ -118,7 +117,7 @@ def test_xgboost_framework_cpu_only(xgboost_framework_version):
118117

119118
@pytest.mark.parametrize("xgboost_algo_version", ALGO_VERSIONS)
120119
def test_xgboost_algo(xgboost_algo_version):
121-
for region in regions.regions():
120+
for region in ALGO_REGISTRIES.keys():
122121
uri = image_uris.retrieve(framework="xgboost", region=region, version=xgboost_algo_version)
123122

124123
expected = expected_uris.algo_uri(

0 commit comments

Comments
 (0)