15
15
import pytest
16
16
17
17
from sagemaker import image_uris
18
- from tests .unit .sagemaker .image_uris import expected_uris , regions
18
+ from tests .unit .sagemaker .image_uris import expected_uris
19
19
20
20
NEO_ALGOS = ("image-classification-neo" , "xgboost-neo" )
21
21
24
24
"ap-east-1" : "110948597952" ,
25
25
"ap-northeast-1" : "941853720454" ,
26
26
"ap-northeast-2" : "151534178276" ,
27
+ "ap-northeast-3" : "925152966179" ,
27
28
"ap-south-1" : "763008648453" ,
28
29
"ap-southeast-1" : "324986816169" ,
29
30
"ap-southeast-2" : "355873309152" ,
50
51
51
52
@pytest .mark .parametrize ("algo" , NEO_ALGOS )
52
53
def test_algo_uris (algo ):
53
- for region in regions .regions ():
54
- if region in ACCOUNTS :
55
- uri = image_uris .retrieve (algo , region )
56
- expected = expected_uris .algo_uri (algo , ACCOUNTS [region ], region , version = "latest" )
57
- assert expected == uri
58
- else :
59
- with pytest .raises (ValueError ) as e :
60
- image_uris .retrieve (algo , region )
61
- assert "Unsupported region: {}." .format (region ) in str (e .value )
54
+ for region in ACCOUNTS .keys ():
55
+ uri = image_uris .retrieve (algo , region )
56
+ expected = expected_uris .algo_uri (algo , ACCOUNTS [region ], region , version = "latest" )
57
+ assert expected == uri
62
58
63
59
64
60
def _test_neo_framework_uris (framework , version ):
65
61
framework_in_config = f"neo-{ framework } "
66
62
framework_in_uri = f"inference-{ framework } "
67
63
68
- for region in regions .regions ():
69
- if region in ACCOUNTS :
70
- uri = image_uris .retrieve (
71
- framework_in_config , region , instance_type = "ml_c5" , version = version
72
- )
73
- assert _expected_framework_uri (framework_in_uri , version , region = region ) == uri
74
- else :
75
- with pytest .raises (ValueError ) as e :
76
- image_uris .retrieve (
77
- framework_in_config , region , instance_type = "ml_c5" , version = version
78
- )
79
- assert "Unsupported region: {}." .format (region ) in str (e .value )
64
+ for region in ACCOUNTS .keys ():
65
+ uri = image_uris .retrieve (
66
+ framework_in_config , region , instance_type = "ml_c5" , version = version
67
+ )
68
+ assert _expected_framework_uri (framework_in_uri , version , region = region ) == uri
80
69
81
70
uri = image_uris .retrieve (
82
71
framework_in_config , "us-west-2" , instance_type = "ml_p2" , version = version
@@ -97,24 +86,14 @@ def test_neo_pytorch(neo_pytorch_version):
97
86
98
87
99
88
def _test_inferentia_framework_uris (framework , version ):
100
- for region in regions .regions ():
101
- if region in INFERENTIA_REGIONS :
102
- uri = image_uris .retrieve (
103
- "inferentia-{}" .format (framework ), region , instance_type = "ml_inf1" , version = version
104
- )
105
- expected = _expected_framework_uri (
106
- "neo-{}" .format (framework ), version , region = region , processor = "inf"
107
- )
108
- assert expected == uri
109
- else :
110
- with pytest .raises (ValueError ) as e :
111
- image_uris .retrieve (
112
- "inferentia-{}" .format (framework ),
113
- region ,
114
- instance_type = "ml_inf" ,
115
- version = version ,
116
- )
117
- assert "Unsupported region: {}." .format (region ) in str (e .value )
89
+ for region in INFERENTIA_REGIONS :
90
+ uri = image_uris .retrieve (
91
+ "inferentia-{}" .format (framework ), region , instance_type = "ml_inf1" , version = version
92
+ )
93
+ expected = _expected_framework_uri (
94
+ "neo-{}" .format (framework ), version , region = region , processor = "inf"
95
+ )
96
+ assert expected == uri
118
97
119
98
120
99
def test_inferentia_mxnet (inferentia_mxnet_version ):
0 commit comments