# Review preamble. It is placed ahead of the first `diff --git` header so that no hunk content is altered.
#
# What this patch does, as shown by the hunks below:
#   * tests/unit/sagemaker/model/test_deploy.py
#       - adds a module-level constant ACCELERATOR_TYPE = "ml.eia.medium" and makes
#         test_deploy_accelerator_type use it in place of its local `accelerator_type` variable.
#       - adds test_deploy_update_endpoint and test_deploy_update_endpoint_optional_args. Both build a
#         plain Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=...), call
#         model.deploy(..., update_endpoint=True), then assert the arguments passed to
#         sagemaker_session.create_endpoint_config and sagemaker_session.update_endpoint and that
#         sagemaker_session.create_endpoint was not called.
#   * tests/unit/sagemaker/model/test_framework_model.py
#       - removes the two same-named tests, which used DummyFrameworkModel(sagemaker_session, source_dir=tmpdir)
#         under @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()).
#       - removes the ACCELERATOR_TYPE and IMAGE_NAME constants.
#
# NOTE(review): in both added tests `config_name` is produced by calling
#   sagemaker_session.create_endpoint_config(...) a second time, with arguments that differ from the call
#   asserted just above it (the first test passes accelerator_type=ACCELERATOR_TYPE although the assertion
#   expects accelerator_type=None; the second passes wait=False, which is not in the asserted call).
#   Presumably sagemaker_session is a Mock fixture whose return value does not depend on the arguments,
#   in which case reading create_endpoint_config.return_value would state the intent directly. TODO confirm.
# NOTE(review): the moved tests now pass initial_instance_count=INSTANCE_COUNT where the removed versions
#   passed the literal 1. The context lines show INSTANCE_COUNT = 2 and INSTANCE_TYPE = "ml.c4.4xlarge" in
#   test_deploy.py versus INSTANCE_COUNT = 1 and INSTANCE_TYPE = "c4.4xlarge" in test_framework_model.py,
#   so the concrete values exercised change with the move.
# NOTE(review): the added optional-args test uses Mock() where the removed one used MagicMock(). The import
#   of Mock in test_deploy.py is not visible in these hunks; verify it exists.
# NOTE(review): IMAGE_NAME is deleted from test_framework_model.py; presumably it has no remaining users
#   there. Verify against the full file.
# NOTE(review): this copy of the patch appears to have its line breaks collapsed into spaces; restore the
#   original line structure before attempting to apply it.
diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index ed911b4021..dd1bb0e57c 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -25,6 +25,7 @@ TIMESTAMP = "2017-10-10-14-14-15" MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP) +ACCELERATOR_TYPE = "ml.eia.medium" INSTANCE_COUNT = 2 INSTANCE_TYPE = "ml.c4.4xlarge" ROLE = "some-role" @@ -83,21 +84,19 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session ) - accelerator_type = "ml.eia.medium" - production_variant_result = copy.deepcopy(BASE_PRODUCTION_VARIANT) - production_variant_result["AcceleratorType"] = accelerator_type + production_variant_result["AcceleratorType"] = ACCELERATOR_TYPE production_variant.return_value = production_variant_result model.deploy( instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, - accelerator_type=accelerator_type, + accelerator_type=ACCELERATOR_TYPE, ) - create_sagemaker_model.assert_called_with(INSTANCE_TYPE, accelerator_type, None) + create_sagemaker_model.assert_called_with(INSTANCE_TYPE, ACCELERATOR_TYPE, None) production_variant.assert_called_with( - MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=accelerator_type + MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=ACCELERATOR_TYPE ) sagemaker_session.endpoint_from_production_variants.assert_called_with( @@ -267,3 +266,69 @@ def test_deploy_predictor_cls(production_variant, sagemaker_session): assert isinstance(predictor, sagemaker.predictor.RealTimePredictor) assert predictor.endpoint == endpoint_name assert predictor.sagemaker_session == sagemaker_session + + +def test_deploy_update_endpoint(sagemaker_session): + model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session) + model.deploy( + instance_type=INSTANCE_TYPE, 
initial_instance_count=INSTANCE_COUNT, update_endpoint=True + ) + sagemaker_session.create_endpoint_config.assert_called_with( + name=model.name, + model_name=model.name, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + ) + config_name = sagemaker_session.create_endpoint_config( + name=model.name, + model_name=model.name, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + ) + sagemaker_session.update_endpoint.assert_called_with(model.name, config_name, wait=True) + sagemaker_session.create_endpoint.assert_not_called() + + +def test_deploy_update_endpoint_optional_args(sagemaker_session): + endpoint_name = "endpoint-name" + tags = [{"Key": "Value"}] + kms_key = "foo" + data_capture_config = Mock() + + model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session) + model.deploy( + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + update_endpoint=True, + endpoint_name=endpoint_name, + accelerator_type=ACCELERATOR_TYPE, + tags=tags, + kms_key=kms_key, + wait=False, + data_capture_config=data_capture_config, + ) + sagemaker_session.create_endpoint_config.assert_called_with( + name=model.name, + model_name=model.name, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + tags=tags, + kms_key=kms_key, + data_capture_config_dict=data_capture_config._to_request_dict(), + ) + config_name = sagemaker_session.create_endpoint_config( + name=model.name, + model_name=model.name, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + wait=False, + ) + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name, wait=False) + sagemaker_session.create_endpoint.assert_not_called() diff --git 
a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py index 09cb5531ce..33f52316ef 100644 --- a/tests/unit/sagemaker/model/test_framework_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -33,8 +33,6 @@ BUCKET_NAME = "mybucket" INSTANCE_COUNT = 1 INSTANCE_TYPE = "c4.4xlarge" -ACCELERATOR_TYPE = "ml.eia.medium" -IMAGE_NAME = "fakeimage" REGION = "us-west-2" MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP) GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" @@ -165,72 +163,6 @@ def test_prepare_container_def_no_model_defaults(sagemaker_session, tmpdir): } -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -def test_deploy_update_endpoint(sagemaker_session, tmpdir): - model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir) - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, update_endpoint=True) - sagemaker_session.create_endpoint_config.assert_called_with( - name=model.name, - model_name=model.name, - initial_instance_count=INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=None, - tags=None, - kms_key=None, - data_capture_config_dict=None, - ) - config_name = sagemaker_session.create_endpoint_config( - name=model.name, - model_name=model.name, - initial_instance_count=INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=ACCELERATOR_TYPE, - ) - sagemaker_session.update_endpoint.assert_called_with(model.name, config_name, wait=True) - sagemaker_session.create_endpoint.assert_not_called() - - -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -def test_deploy_update_endpoint_optional_args(sagemaker_session, tmpdir): - endpoint_name = "endpoint-name" - tags = [{"Key": "Value"}] - kms_key = "foo" - data_capture_config = MagicMock() - - model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir) - model.deploy( - instance_type=INSTANCE_TYPE, - initial_instance_count=1, - update_endpoint=True, - 
endpoint_name=endpoint_name, - accelerator_type=ACCELERATOR_TYPE, - tags=tags, - kms_key=kms_key, - wait=False, - data_capture_config=data_capture_config, - ) - sagemaker_session.create_endpoint_config.assert_called_with( - name=model.name, - model_name=model.name, - initial_instance_count=INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=ACCELERATOR_TYPE, - tags=tags, - kms_key=kms_key, - data_capture_config_dict=data_capture_config._to_request_dict(), - ) - config_name = sagemaker_session.create_endpoint_config( - name=model.name, - model_name=model.name, - initial_instance_count=INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=ACCELERATOR_TYPE, - wait=False, - ) - sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name, wait=False) - sagemaker_session.create_endpoint.assert_not_called() - - @patch("sagemaker.git_utils.git_clone_repo") @patch("sagemaker.model.fw_utils.tar_and_upload_dir") def test_git_support_succeed(tar_and_upload_dir, git_clone_repo, sagemaker_session):