Skip to content

Commit 03c277c

Browse files
authored
Add MXNet version to integ tests using MXNetModel (#626)
1 parent e1b459f commit 03c277c

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

tests/integ/test_mxnet_train.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,30 +57,32 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session):
5757
predictor.predict(data)
5858

5959

60-
def test_deploy_model(mxnet_training_job, sagemaker_session):
60+
def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version):
6161
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
6262

6363
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
6464
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
6565
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
6666
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
6767
model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
68-
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session)
68+
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
69+
framework_version=mxnet_full_version)
6970
predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
7071

7172
data = numpy.zeros(shape=(1, 1, 28, 28))
7273
predictor.predict(data)
7374

7475

75-
def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session):
76+
def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):
7677
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
7778

7879
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
7980
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
8081
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
8182
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
8283
model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
83-
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session)
84+
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
85+
framework_version=mxnet_full_version)
8486
model.deploy(1, 'ml.t2.medium', endpoint_name=endpoint_name)
8587
old_endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
8688
old_config_name = old_endpoint['EndpointConfigName']
@@ -96,7 +98,7 @@ def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session
9698
assert new_production_variants['AcceleratorType'] is None
9799

98100

99-
def test_deploy_model_with_update_non_existing_endpoint(mxnet_training_job, sagemaker_session):
101+
def test_deploy_model_with_update_non_existing_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):
100102
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
101103
expected_error_message = 'Endpoint with name "{}" does not exist; ' \
102104
'please use an existing endpoint name'.format(endpoint_name)
@@ -106,7 +108,8 @@ def test_deploy_model_with_update_non_existing_endpoint(mxnet_training_job, sage
106108
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
107109
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
108110
model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
109-
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session)
111+
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
112+
framework_version=mxnet_full_version)
110113
model.deploy(1, 'ml.t2.medium', endpoint_name=endpoint_name)
111114
sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
112115

0 commit comments

Comments
 (0)