Skip to content

Commit 852ffdd

Browse files
committed
Model and Estimator UTs
1 parent 483b143 commit 852ffdd

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

tests/unit/sagemaker/model/test_deploy.py

+61
Original file line numberDiff line numberDiff line change
@@ -483,3 +483,64 @@ def test_deploy_predictor_cls(production_variant, sagemaker_session):
483483
assert predictor_async.name == model.name
484484
assert predictor_async.endpoint_name == endpoint_name_async
485485
assert predictor_async.sagemaker_session == sagemaker_session
486+
487+
@patch("sagemaker.production_variant")
488+
@patch("sagemaker.model.Model.prepare_container_def")
489+
@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME)
490+
def test_deploy_customized_volume_size_and_timeout(name_from_base, prepare_container_def, production_variant, sagemaker_session):
491+
volume_size_gb = 256
492+
model_data_download_timeout_sec = 1800
493+
startup_health_check_timeout_sec = 1800
494+
495+
production_variant_result = copy.deepcopy(BASE_PRODUCTION_VARIANT)
496+
production_variant_result.update({
497+
'VolumeSizeInGB': volume_size_gb,
498+
'ModelDataDownloadTimeoutInSeconds': model_data_download_timeout_sec,
499+
'ContainerStartupHealthCheckTimeoutInSeconds': startup_health_check_timeout_sec,
500+
})
501+
production_variant.return_value = production_variant_result
502+
503+
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
504+
prepare_container_def.return_value = container_def
505+
506+
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
507+
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT,
508+
volume_size=volume_size_gb,
509+
model_data_download_timeout=model_data_download_timeout_sec,
510+
container_startup_health_check_timeout=startup_health_check_timeout_sec)
511+
512+
name_from_base.assert_called_with(MODEL_IMAGE)
513+
assert 2 == name_from_base.call_count
514+
515+
prepare_container_def.assert_called_with(
516+
INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None
517+
)
518+
production_variant.assert_called_with(
519+
MODEL_NAME,
520+
INSTANCE_TYPE,
521+
INSTANCE_COUNT,
522+
accelerator_type=None,
523+
serverless_inference_config=None,
524+
volume_size=volume_size_gb,
525+
model_data_download_timeout=model_data_download_timeout_sec,
526+
container_startup_health_check_timeout=startup_health_check_timeout_sec,
527+
)
528+
529+
sagemaker_session.create_model.assert_called_with(
530+
name=MODEL_NAME,
531+
role=ROLE,
532+
container_defs=container_def,
533+
vpc_config=None,
534+
enable_network_isolation=False,
535+
tags=None,
536+
)
537+
538+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
539+
name=MODEL_NAME,
540+
production_variants=[production_variant_result],
541+
tags=None,
542+
kms_key=None,
543+
wait=True,
544+
data_capture_config_dict=None,
545+
async_inference_config_dict=None,
546+
)

tests/unit/test_estimator.py

+44
Original file line numberDiff line numberDiff line change
@@ -3312,6 +3312,50 @@ def test_deploy_with_no_model_name(sagemaker_session):
33123312
assert kwargs["name"].startswith(IMAGE_URI)
33133313

33143314

3315+
@patch("sagemaker.estimator.Estimator.create_model")
3316+
def test_deploy_with_customized_volume_size_timeout(create_model, sagemaker_session):
3317+
estimator = Estimator(
3318+
IMAGE_URI,
3319+
ROLE,
3320+
INSTANCE_COUNT,
3321+
INSTANCE_TYPE,
3322+
output_path=OUTPUT_PATH,
3323+
sagemaker_session=sagemaker_session,
3324+
)
3325+
estimator.set_hyperparameters(**HYPERPARAMS)
3326+
estimator.fit({"train": "s3://bucket/training-prefix"})
3327+
endpoint_name = "endpoint-name"
3328+
volume_size_gb = 256
3329+
model_data_download_timeout_sec = 600
3330+
startup_health_check_timeout_sec = 600
3331+
3332+
model = MagicMock()
3333+
create_model.return_value = model
3334+
3335+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_name=endpoint_name,
3336+
volume_size=volume_size_gb,
3337+
model_data_download_timeout=model_data_download_timeout_sec,
3338+
container_startup_health_check_timeout=startup_health_check_timeout_sec)
3339+
3340+
model.deploy.assert_called_with(
3341+
instance_type=INSTANCE_TYPE,
3342+
initial_instance_count=INSTANCE_COUNT,
3343+
serializer=None,
3344+
deserializer=None,
3345+
accelerator_type=None,
3346+
endpoint_name=endpoint_name,
3347+
tags=None,
3348+
wait=True,
3349+
kms_key=None,
3350+
data_capture_config=None,
3351+
async_inference_config=None,
3352+
serverless_inference_config=None,
3353+
volume_size=volume_size_gb,
3354+
model_data_download_timeout=model_data_download_timeout_sec,
3355+
container_startup_health_check_timeout=startup_health_check_timeout_sec,
3356+
)
3357+
3358+
33153359
def test_register_default_image(sagemaker_session):
33163360
estimator = Estimator(
33173361
IMAGE_URI,

0 commit comments

Comments
 (0)