@@ -484,30 +484,38 @@ def test_deploy_predictor_cls(production_variant, sagemaker_session):
484
484
assert predictor_async .endpoint_name == endpoint_name_async
485
485
assert predictor_async .sagemaker_session == sagemaker_session
486
486
487
+
487
488
@patch ("sagemaker.production_variant" )
488
489
@patch ("sagemaker.model.Model.prepare_container_def" )
489
490
@patch ("sagemaker.utils.name_from_base" , return_value = MODEL_NAME )
490
- def test_deploy_customized_volume_size_and_timeout (name_from_base , prepare_container_def , production_variant , sagemaker_session ):
491
+ def test_deploy_customized_volume_size_and_timeout (
492
+ name_from_base , prepare_container_def , production_variant , sagemaker_session
493
+ ):
491
494
volume_size_gb = 256
492
495
model_data_download_timeout_sec = 1800
493
496
startup_health_check_timeout_sec = 1800
494
497
495
498
production_variant_result = copy .deepcopy (BASE_PRODUCTION_VARIANT )
496
- production_variant_result .update ({
497
- 'VolumeSizeInGB' : volume_size_gb ,
498
- 'ModelDataDownloadTimeoutInSeconds' : model_data_download_timeout_sec ,
499
- 'ContainerStartupHealthCheckTimeoutInSeconds' : startup_health_check_timeout_sec ,
500
- })
499
+ production_variant_result .update (
500
+ {
501
+ "VolumeSizeInGB" : volume_size_gb ,
502
+ "ModelDataDownloadTimeoutInSeconds" : model_data_download_timeout_sec ,
503
+ "ContainerStartupHealthCheckTimeoutInSeconds" : startup_health_check_timeout_sec ,
504
+ }
505
+ )
501
506
production_variant .return_value = production_variant_result
502
507
503
508
container_def = {"Image" : MODEL_IMAGE , "Environment" : {}, "ModelDataUrl" : MODEL_DATA }
504
509
prepare_container_def .return_value = container_def
505
510
506
511
model = Model (MODEL_IMAGE , MODEL_DATA , role = ROLE , sagemaker_session = sagemaker_session )
507
- model .deploy (instance_type = INSTANCE_TYPE , initial_instance_count = INSTANCE_COUNT ,
508
- volume_size = volume_size_gb ,
509
- model_data_download_timeout = model_data_download_timeout_sec ,
510
- container_startup_health_check_timeout = startup_health_check_timeout_sec )
512
+ model .deploy (
513
+ instance_type = INSTANCE_TYPE ,
514
+ initial_instance_count = INSTANCE_COUNT ,
515
+ volume_size = volume_size_gb ,
516
+ model_data_download_timeout = model_data_download_timeout_sec ,
517
+ container_startup_health_check_timeout = startup_health_check_timeout_sec ,
518
+ )
511
519
512
520
name_from_base .assert_called_with (MODEL_IMAGE )
513
521
assert 2 == name_from_base .call_count
0 commit comments