@@ -159,6 +159,7 @@ def test_mnist_async(sagemaker_session):
159
159
training_job_name = estimator .latest_training_job .name
160
160
time .sleep (20 )
161
161
endpoint_name = training_job_name
162
+ model_name = 'model-name-1'
162
163
_assert_training_job_tags_match (
163
164
sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS
164
165
)
@@ -167,7 +168,8 @@ def test_mnist_async(sagemaker_session):
167
168
training_job_name = training_job_name , sagemaker_session = sagemaker_session
168
169
)
169
170
predictor = estimator .deploy (
170
- initial_instance_count = 1 , instance_type = "ml.c4.xlarge" , endpoint_name = endpoint_name
171
+ initial_instance_count = 1 , instance_type = "ml.c4.xlarge" , endpoint_name = endpoint_name ,
172
+ model_name = model_name
171
173
)
172
174
173
175
result = predictor .predict (np .zeros (784 ))
@@ -176,6 +178,9 @@ def test_mnist_async(sagemaker_session):
176
178
_assert_model_tags_match (
177
179
sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS
178
180
)
181
+ _assert_model_name_match (
182
+ sagemaker_session .sagemaker_client , endpoint_name , model_name
183
+ )
179
184
180
185
181
186
def test_deploy_with_input_handlers (sagemaker_session , instance_type ):
@@ -241,3 +246,10 @@ def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
241
246
TrainingJobName = training_job_name
242
247
)
243
248
_assert_tags_match (sagemaker_client , training_job_description ["TrainingJobArn" ], tags )
249
+
250
+
251
+ def _assert_model_name_match (sagemaker_client , endpoint_config_name , model_name ):
252
+ endpoint_config_description = sagemaker_client .describe_endpoint_config (
253
+ EndpointConfigName = endpoint_config_name
254
+ )
255
+ assert model_name == endpoint_config_description ['ProductionVariants' ][0 ]['ModelName' ]
0 commit comments