Skip to content

Commit 1130fa6

Browse files
committed
fixed linting issues
1 parent a4f22d8 commit 1130fa6

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

src/sagemaker/fw_utils.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
"mxnet": "mxnet-training",
6565
"tensorflow-serving": "tensorflow-inference",
6666
"tensorflow-serving-eia": "tensorflow-inference-eia",
67-
"mxnet-serving": "mxnet-inference",
6867
"mxnet-serving-eia": "mxnet-inference-eia",
6968
}
7069

@@ -73,7 +72,6 @@
7372
"mxnet": [1, 4, 1],
7473
"tensorflow-serving": [1, 13, 0],
7574
"tensorflow-serving-eia": [1, 14, 0],
76-
"mxnet-serving": [1, 4, 1],
7775
"mxnet-serving-eia": [1, 4, 1],
7876
}
7977

@@ -212,6 +210,13 @@ def create_image_uri(
212210
if py_version and py_version not in VALID_PY_VERSIONS:
213211
raise ValueError("invalid py_version argument: {}".format(py_version))
214212

213+
if _accelerator_type_valid_for_framework(
214+
framework=framework,
215+
accelerator_type=accelerator_type,
216+
optimized_families=optimized_families,
217+
):
218+
framework += "-eia"
219+
215220
# Handle Account Number for Gov Cloud and frameworks with DLC merged images
216221
account = _registry_id(
217222
region=region,
@@ -242,13 +247,6 @@ def create_image_uri(
242247
else:
243248
device_type = "cpu"
244249

245-
if _accelerator_type_valid_for_framework(
246-
framework=framework,
247-
accelerator_type=accelerator_type,
248-
optimized_families=optimized_families,
249-
):
250-
framework += "-eia"
251-
252250
using_merged_images = _using_merged_images(region, framework, py_version, framework_version)
253251

254252
if not py_version or (using_merged_images and framework == "tensorflow-serving-eia"):

tests/unit/test_fw_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,23 @@ def test_create_image_uri_merged():
218218
image_uri = fw_utils.create_image_uri(
219219
"us-west-2", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py3"
220220
)
221-
assert image_uri == "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.4.1-cpu-py3"
221+
assert (
222+
image_uri
223+
== "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-serving:1.4.1-cpu-py3"
224+
)
225+
226+
image_uri = fw_utils.create_image_uri(
227+
"us-west-2",
228+
"mxnet-serving",
229+
"ml.c4.2xlarge",
230+
"1.4.1",
231+
"py3",
232+
accelerator_type="ml.eia1.medium",
233+
)
234+
assert (
235+
image_uri
236+
== "763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference-eia:1.4.1-cpu-py3"
237+
)
222238

223239

224240
def test_create_image_uri_merged_py2():

tests/unit/test_mxnet.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -336,12 +336,7 @@ def test_mxnet_mms_version(
336336

337337
model = mx.create_model()
338338

339-
if mxnet_version == "1.4.1":
340-
expected_image_base = (
341-
"763104351884.dkr.ecr.us-west-2.amazonaws.com/mxnet-inference:1.4.1-gpu-py2"
342-
)
343-
else:
344-
expected_image_base = _get_full_image_uri(mxnet_version, IMAGE_REPO_SERVING_NAME, "gpu")
339+
expected_image_base = _get_full_image_uri(mxnet_version, IMAGE_REPO_SERVING_NAME, "gpu")
345340

346341
environment = {
347342
"Environment": {

0 commit comments

Comments
 (0)