Skip to content

Commit 0578a3d

Browse files
authored
change: add mxnet 1.7.0 eia configuration (#2173)
1 parent 4d3844b commit 0578a3d

File tree

4 files changed

+103
-6
lines changed

4 files changed

+103
-6
lines changed

src/sagemaker/image_uri_config/mxnet.json

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,8 @@
694694
"version_aliases": {
695695
"1.3": "1.3.0",
696696
"1.4": "1.4.1",
697-
"1.5": "1.5.1"
697+
"1.5": "1.5.1",
698+
"1.7": "1.7.0"
698699
},
699700
"versions": {
700701
"1.3.0": {
@@ -816,6 +817,36 @@
816817
},
817818
"repository": "mxnet-inference-eia",
818819
"py_versions": ["py2", "py3"]
820+
},
821+
"1.7.0": {
822+
"registries": {
823+
"af-south-1": "626614931356",
824+
"ap-east-1": "871362719292",
825+
"ap-northeast-1": "763104351884",
826+
"ap-northeast-2": "763104351884",
827+
"ap-south-1": "763104351884",
828+
"ap-southeast-1": "763104351884",
829+
"ap-southeast-2": "763104351884",
830+
"ca-central-1": "763104351884",
831+
"cn-north-1": "727897471807",
832+
"cn-northwest-1": "727897471807",
833+
"eu-central-1": "763104351884",
834+
"eu-north-1": "763104351884",
835+
"eu-west-1": "763104351884",
836+
"eu-west-2": "763104351884",
837+
"eu-west-3": "763104351884",
838+
"eu-south-1": "692866216735",
839+
"me-south-1": "217643126080",
840+
"sa-east-1": "763104351884",
841+
"us-east-1": "763104351884",
842+
"us-east-2": "763104351884",
843+
"us-gov-west-1": "442386744353",
844+
"us-iso-east-1": "886529160074",
845+
"us-west-1": "763104351884",
846+
"us-west-2": "763104351884"
847+
},
848+
"repository": "mxnet-inference-eia",
849+
"py_versions": ["py3"]
819850
}
820851
}
821852
}

tests/conftest.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,16 @@ def mxnet_training_py_version(mxnet_training_version, request):
157157

158158

159159
@pytest.fixture(scope="module", params=["py2", "py3"])
160-
def mxnet_eia_py_version(request):
161-
return request.param
160+
def mxnet_eia_py_version(mxnet_eia_version, request):
161+
if Version(mxnet_eia_version) < Version("1.7.0"):
162+
return request.param
163+
else:
164+
return "py3"
165+
166+
167+
@pytest.fixture(scope="module")
168+
def mxnet_eia_latest_py_version():
169+
return "py3"
162170

163171

164172
@pytest.fixture(scope="module", params=["py2", "py3"])

tests/data/mxnet_mnist/mnist_ei.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License"). You
2+
# may not use this file except in compliance with the License. A copy of
3+
# the License is located at
4+
#
5+
# http://aws.amazon.com/apache2.0/
6+
#
7+
# or in the "license" file accompanying this file. This file is
8+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
9+
# ANY KIND, either express or implied. See the License for the specific
10+
# language governing permissions and limitations under the License.
11+
from __future__ import absolute_import
12+
13+
import argparse
14+
import gzip
15+
import json
16+
import logging
17+
import os
18+
import struct
19+
20+
import mxnet as mx
21+
import numpy as np
22+
23+
24+
def model_fn(model_dir):
25+
import eimx
26+
27+
def read_data_shapes(path, preferred_batch_size=1):
28+
with open(path, "r") as f:
29+
signatures = json.load(f)
30+
31+
data_names = []
32+
data_shapes = []
33+
34+
for s in signatures:
35+
name = s["name"]
36+
data_names.append(name)
37+
38+
shape = s["shape"]
39+
40+
if preferred_batch_size:
41+
shape[0] = preferred_batch_size
42+
43+
data_shapes.append((name, shape))
44+
45+
return data_names, data_shapes
46+
47+
shapes_file = os.path.join(model_dir, "model-shapes.json")
48+
data_names, data_shapes = read_data_shapes(shapes_file)
49+
50+
ctx = mx.cpu()
51+
sym, args, aux = mx.model.load_checkpoint(os.path.join(model_dir, "model"), 0)
52+
sym = sym.optimize_for("EIA")
53+
54+
mod = mx.mod.Module(symbol=sym, context=ctx, data_names=data_names, label_names=None)
55+
mod.bind(for_training=False, data_shapes=data_shapes)
56+
mod.set_params(args, aux, allow_missing=True)
57+
58+
return mod

tests/integ/test_mxnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def test_deploy_model_with_accelerator(
313313
mxnet_training_job,
314314
sagemaker_session,
315315
mxnet_eia_latest_version,
316-
mxnet_eia_py_version,
316+
mxnet_eia_latest_py_version,
317317
cpu_instance_type,
318318
):
319319
endpoint_name = "test-mxnet-deploy-model-ei-{}".format(sagemaker_timestamp())
@@ -323,13 +323,13 @@ def test_deploy_model_with_accelerator(
323323
TrainingJobName=mxnet_training_job
324324
)
325325
model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
326-
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
326+
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_ei.py")
327327
model = MXNetModel(
328328
model_data,
329329
"SageMakerRole",
330330
entry_point=script_path,
331331
framework_version=mxnet_eia_latest_version,
332-
py_version=mxnet_eia_py_version,
332+
py_version=mxnet_eia_latest_py_version,
333333
sagemaker_session=sagemaker_session,
334334
)
335335
predictor = model.deploy(

0 commit comments

Comments
 (0)