diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 1e47f2cc69..53aa005d7e 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -694,7 +694,8 @@ "version_aliases": { "1.3": "1.3.0", "1.4": "1.4.1", - "1.5": "1.5.1" + "1.5": "1.5.1", + "1.7": "1.7.0" }, "versions": { "1.3.0": { @@ -816,6 +817,36 @@ }, "repository": "mxnet-inference-eia", "py_versions": ["py2", "py3"] + }, + "1.7.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "mxnet-inference-eia", + "py_versions": ["py3"] } } } diff --git a/tests/conftest.py b/tests/conftest.py index fd88019091..6e9794fe07 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -157,8 +157,16 @@ def mxnet_training_py_version(mxnet_training_version, request): @pytest.fixture(scope="module", params=["py2", "py3"]) -def mxnet_eia_py_version(request): - return request.param +def mxnet_eia_py_version(mxnet_eia_version, request): + if Version(mxnet_eia_version) < Version("1.7.0"): + return request.param + else: + return "py3" + + +@pytest.fixture(scope="module") +def mxnet_eia_latest_py_version(): + return "py3" @pytest.fixture(scope="module", params=["py2", "py3"]) diff --git a/tests/data/mxnet_mnist/mnist_ei.py b/tests/data/mxnet_mnist/mnist_ei.py new file mode 100644 index 0000000000..7f0b2ea684 --- /dev/null +++ b/tests/data/mxnet_mnist/mnist_ei.py @@ -0,0 +1,58 @@ +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import argparse +import gzip +import json +import logging +import os +import struct + +import mxnet as mx +import numpy as np + + +def model_fn(model_dir): + import eimx + + def read_data_shapes(path, preferred_batch_size=1): + with open(path, "r") as f: + signatures = json.load(f) + + data_names = [] + data_shapes = [] + + for s in signatures: + name = s["name"] + data_names.append(name) + + shape = s["shape"] + + if preferred_batch_size: + shape[0] = preferred_batch_size + + data_shapes.append((name, shape)) + + return data_names, data_shapes + + shapes_file = os.path.join(model_dir, "model-shapes.json") + data_names, data_shapes = read_data_shapes(shapes_file) + + ctx = mx.cpu() + sym, args, aux = mx.model.load_checkpoint(os.path.join(model_dir, "model"), 0) + sym = sym.optimize_for("EIA") + + mod = mx.mod.Module(symbol=sym, context=ctx, data_names=data_names, label_names=None) + mod.bind(for_training=False, data_shapes=data_shapes) + mod.set_params(args, aux, allow_missing=True) + + return mod diff --git a/tests/integ/test_mxnet.py b/tests/integ/test_mxnet.py index a583fd05cc..65917df663 100644 --- a/tests/integ/test_mxnet.py +++ b/tests/integ/test_mxnet.py @@ -313,7 +313,7 @@ def test_deploy_model_with_accelerator( mxnet_training_job, sagemaker_session, mxnet_eia_latest_version, - mxnet_eia_py_version, + mxnet_eia_latest_py_version, cpu_instance_type, ): endpoint_name = "test-mxnet-deploy-model-ei-{}".format(sagemaker_timestamp()) @@ -323,13 +323,13 @@ def test_deploy_model_with_accelerator( TrainingJobName=mxnet_training_job ) model_data = desc["ModelArtifacts"]["S3ModelArtifacts"] - script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py") + script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_ei.py") model = MXNetModel( model_data, "SageMakerRole", entry_point=script_path, framework_version=mxnet_eia_latest_version, - py_version=mxnet_eia_py_version, + py_version=mxnet_eia_latest_py_version, sagemaker_session=sagemaker_session, ) predictor = model.deploy(