@@ -483,3 +483,64 @@ def test_deploy_predictor_cls(production_variant, sagemaker_session):
483
483
assert predictor_async .name == model .name
484
484
assert predictor_async .endpoint_name == endpoint_name_async
485
485
assert predictor_async .sagemaker_session == sagemaker_session
486
+
487
@patch("sagemaker.production_variant")
@patch("sagemaker.model.Model.prepare_container_def")
@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME)
def test_deploy_customized_volume_size_and_timeout(
    name_from_base, prepare_container_def, production_variant, sagemaker_session
):
    """Check that deploy() forwards volume size and timeout overrides.

    Deploys a ``Model`` with ``volume_size``, ``model_data_download_timeout``
    and ``container_startup_health_check_timeout`` set, then asserts that
    those values are handed to the patched ``sagemaker.production_variant``
    and that the variant it returns is the one passed to
    ``endpoint_from_production_variants``.

    The mock parameters follow the bottom-up application order of the
    ``@patch`` decorators above.
    """
    requested_volume_gb = 256
    download_timeout_sec = 1800
    health_check_timeout_sec = 1800

    # The same overrides are passed to deploy() and expected on the
    # production_variant call, so keep them in one place.
    deploy_overrides = {
        "volume_size": requested_volume_gb,
        "model_data_download_timeout": download_timeout_sec,
        "container_startup_health_check_timeout": health_check_timeout_sec,
    }

    # Deep copy so the shared module-level base variant is never mutated.
    expected_variant = copy.deepcopy(BASE_PRODUCTION_VARIANT)
    expected_variant["VolumeSizeInGB"] = requested_volume_gb
    expected_variant["ModelDataDownloadTimeoutInSeconds"] = download_timeout_sec
    expected_variant["ContainerStartupHealthCheckTimeoutInSeconds"] = health_check_timeout_sec
    production_variant.return_value = expected_variant

    expected_container = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
    prepare_container_def.return_value = expected_container

    model_under_test = Model(
        MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session
    )
    model_under_test.deploy(
        instance_type=INSTANCE_TYPE,
        initial_instance_count=INSTANCE_COUNT,
        **deploy_overrides
    )

    # name_from_base is expected to be invoked twice during deploy, each
    # time with the image URI as the base.
    name_from_base.assert_called_with(MODEL_IMAGE)
    assert name_from_base.call_count == 2

    prepare_container_def.assert_called_with(
        INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None
    )

    production_variant.assert_called_with(
        MODEL_NAME,
        INSTANCE_TYPE,
        INSTANCE_COUNT,
        accelerator_type=None,
        serverless_inference_config=None,
        **deploy_overrides
    )

    sagemaker_session.create_model.assert_called_with(
        name=MODEL_NAME,
        role=ROLE,
        container_defs=expected_container,
        vpc_config=None,
        enable_network_isolation=False,
        tags=None,
    )

    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        name=MODEL_NAME,
        production_variants=[expected_variant],
        tags=None,
        kms_key=None,
        wait=True,
        data_capture_config_dict=None,
        async_inference_config_dict=None,
    )
0 commit comments