@@ -89,7 +89,7 @@ def test_graviton_pytorch(graviton_pytorch_version):
89
89
_test_graviton_framework_uris ("pytorch" , graviton_pytorch_version )
90
90
91
91
92
- def test_graviton_xgboost (graviton_xgboost_versions ):
92
+ def test_graviton_xgboost_instance_type_specified (graviton_xgboost_versions ):
93
93
for xgboost_version in graviton_xgboost_versions :
94
94
for instance_type in GRAVITON_INSTANCE_TYPES :
95
95
uri = image_uris .retrieve (
@@ -102,6 +102,19 @@ def test_graviton_xgboost(graviton_xgboost_versions):
102
102
assert expected == uri
103
103
104
104
105
+ def test_graviton_xgboost_image_scope_specified (graviton_xgboost_versions ):
106
+ for xgboost_version in graviton_xgboost_versions :
107
+ for instance_type in GRAVITON_INSTANCE_TYPES :
108
+ uri = image_uris .retrieve (
109
+ "xgboost" , "us-west-2" , version = xgboost_version , image_scope = "inference_graviton"
110
+ )
111
+ expected = (
112
+ "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:"
113
+ f"{ xgboost_version } -arm64"
114
+ )
115
+ assert expected == uri
116
+
117
+
105
118
def test_graviton_xgboost_unsupported_version (graviton_xgboost_unsupported_versions ):
106
119
for xgboost_version in graviton_xgboost_unsupported_versions :
107
120
for instance_type in GRAVITON_INSTANCE_TYPES :
@@ -112,7 +125,7 @@ def test_graviton_xgboost_unsupported_version(graviton_xgboost_unsupported_versi
112
125
assert f"Unsupported xgboost version: { xgboost_version } ." in str (error )
113
126
114
127
115
- def test_graviton_sklearn (graviton_sklearn_versions ):
128
+ def test_graviton_sklearn_instance_type_specified (graviton_sklearn_versions ):
116
129
for sklearn_version in graviton_sklearn_versions :
117
130
for instance_type in GRAVITON_INSTANCE_TYPES :
118
131
uri = image_uris .retrieve (
@@ -125,6 +138,19 @@ def test_graviton_sklearn(graviton_sklearn_versions):
125
138
assert expected == uri
126
139
127
140
141
+ def test_graviton_sklearn_image_scope_specified (graviton_sklearn_versions ):
142
+ for sklearn_version in graviton_sklearn_versions :
143
+ for instance_type in GRAVITON_INSTANCE_TYPES :
144
+ uri = image_uris .retrieve (
145
+ "sklearn" , "us-west-2" , version = sklearn_version , image_scope = "inference_graviton"
146
+ )
147
+ expected = (
148
+ "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:"
149
+ f"{ sklearn_version } -arm64-cpu-py3"
150
+ )
151
+ assert expected == uri
152
+
153
+
128
154
def test_graviton_sklearn_unsupported_version (graviton_sklearn_unsupported_versions ):
129
155
for sklearn_version in graviton_sklearn_unsupported_versions :
130
156
for instance_type in GRAVITON_INSTANCE_TYPES :
0 commit comments