|
25 | 25 | TIMESTAMP = "2017-10-10-14-14-15"
|
26 | 26 | MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
|
27 | 27 |
|
| 28 | +ACCELERATOR_TYPE = "ml.eia.medium" |
28 | 29 | INSTANCE_COUNT = 2
|
29 | 30 | INSTANCE_TYPE = "ml.c4.4xlarge"
|
30 | 31 | ROLE = "some-role"
|
@@ -83,21 +84,19 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag
|
83 | 84 | MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
|
84 | 85 | )
|
85 | 86 |
|
86 |
| - accelerator_type = "ml.eia.medium" |
87 |
| - |
88 | 87 | production_variant_result = copy.deepcopy(BASE_PRODUCTION_VARIANT)
|
89 |
| - production_variant_result["AcceleratorType"] = accelerator_type |
| 88 | + production_variant_result["AcceleratorType"] = ACCELERATOR_TYPE |
90 | 89 | production_variant.return_value = production_variant_result
|
91 | 90 |
|
92 | 91 | model.deploy(
|
93 | 92 | instance_type=INSTANCE_TYPE,
|
94 | 93 | initial_instance_count=INSTANCE_COUNT,
|
95 |
| - accelerator_type=accelerator_type, |
| 94 | + accelerator_type=ACCELERATOR_TYPE, |
96 | 95 | )
|
97 | 96 |
|
98 |
| - create_sagemaker_model.assert_called_with(INSTANCE_TYPE, accelerator_type, None) |
| 97 | + create_sagemaker_model.assert_called_with(INSTANCE_TYPE, ACCELERATOR_TYPE, None) |
99 | 98 | production_variant.assert_called_with(
|
100 |
| - MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=accelerator_type |
| 99 | + MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=ACCELERATOR_TYPE |
101 | 100 | )
|
102 | 101 |
|
103 | 102 | sagemaker_session.endpoint_from_production_variants.assert_called_with(
|
@@ -267,3 +266,69 @@ def test_deploy_predictor_cls(production_variant, sagemaker_session):
|
267 | 266 | assert isinstance(predictor, sagemaker.predictor.RealTimePredictor)
|
268 | 267 | assert predictor.endpoint == endpoint_name
|
269 | 268 | assert predictor.sagemaker_session == sagemaker_session
|
| 269 | + |
| 270 | + |
def test_deploy_update_endpoint(sagemaker_session):
    """Check that ``deploy(update_endpoint=True)`` updates rather than creates an endpoint.

    With only the required arguments supplied, the test verifies that:

    * ``create_endpoint_config`` is called with the model's name as both the
      config name and the model name, and with ``None`` for every optional
      setting (accelerator type, tags, KMS key, data capture config);
    * ``update_endpoint`` is called with the model's name, the value returned
      by ``create_endpoint_config``, and ``wait=True``;
    * ``create_endpoint`` is never called.
    """
    model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
    model.deploy(
        instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, update_endpoint=True
    )

    sagemaker_session.create_endpoint_config.assert_called_with(
        name=model.name,
        model_name=model.name,
        initial_instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        accelerator_type=None,
        tags=None,
        kms_key=None,
        data_capture_config_dict=None,
    )

    # Read the mock's return value directly instead of invoking the mock again:
    # a second call would add a spurious entry to its call history and would
    # need arguments that have nothing to do with the deploy() call above.
    # NOTE(review): assumes the sagemaker_session fixture is a plain Mock with
    # no side_effect on create_endpoint_config -- confirm against the fixture.
    config_name = sagemaker_session.create_endpoint_config.return_value

    sagemaker_session.update_endpoint.assert_called_with(model.name, config_name, wait=True)
    sagemaker_session.create_endpoint.assert_not_called()
| 295 | + |
| 296 | + |
def test_deploy_update_endpoint_optional_args(sagemaker_session):
    """Check that optional ``deploy`` arguments are forwarded when updating an endpoint.

    The model is deployed with ``update_endpoint=True`` plus an explicit
    endpoint name, accelerator type, tags, KMS key, ``wait=False`` and a data
    capture config. The test verifies that:

    * ``create_endpoint_config`` receives the accelerator type, tags and KMS
      key unchanged, and the data capture config converted through its
      ``_to_request_dict()`` method;
    * ``update_endpoint`` is called with the explicit endpoint name, the value
      returned by ``create_endpoint_config``, and ``wait=False``;
    * ``create_endpoint`` is never called.
    """
    endpoint_name = "endpoint-name"
    tags = [{"Key": "Value"}]
    kms_key = "foo"
    data_capture_config = Mock()

    model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
    model.deploy(
        instance_type=INSTANCE_TYPE,
        initial_instance_count=INSTANCE_COUNT,
        update_endpoint=True,
        endpoint_name=endpoint_name,
        accelerator_type=ACCELERATOR_TYPE,
        tags=tags,
        kms_key=kms_key,
        wait=False,
        data_capture_config=data_capture_config,
    )

    sagemaker_session.create_endpoint_config.assert_called_with(
        name=model.name,
        model_name=model.name,
        initial_instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        accelerator_type=ACCELERATOR_TYPE,
        tags=tags,
        kms_key=kms_key,
        data_capture_config_dict=data_capture_config._to_request_dict(),
    )

    # Read the mock's return value directly instead of invoking the mock again:
    # a second call would add a spurious entry to its call history, and the
    # previous version passed it a ``wait`` keyword that is not among the
    # arguments asserted for create_endpoint_config above.
    # NOTE(review): assumes the sagemaker_session fixture is a plain Mock with
    # no side_effect on create_endpoint_config -- confirm against the fixture.
    config_name = sagemaker_session.create_endpoint_config.return_value

    sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name, wait=False)
    sagemaker_session.create_endpoint.assert_not_called()
0 commit comments