From 1927b1c7d2b9ad8c66899363181383fc79578ed3 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Wed, 15 Apr 2020 11:47:51 -0700 Subject: [PATCH 1/4] infra: split Model and ModelPackage unit tests into different files --- .../unit/{ => sagemaker/model}/test_model.py | 116 +--------------- .../sagemaker/model/test_model_package.py | 128 ++++++++++++++++++ 2 files changed, 129 insertions(+), 115 deletions(-) rename tests/unit/{ => sagemaker/model}/test_model.py (87%) create mode 100644 tests/unit/sagemaker/model/test_model_package.py diff --git a/tests/unit/test_model.py b/tests/unit/sagemaker/model/test_model.py similarity index 87% rename from tests/unit/test_model.py rename to tests/unit/sagemaker/model/test_model.py index 5c69566b14..462ed31b23 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -12,12 +12,11 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import copy import os import subprocess import sagemaker -from sagemaker.model import FrameworkModel, ModelPackage +from sagemaker.model import FrameworkModel from sagemaker.predictor import RealTimePredictor import pytest @@ -53,39 +52,6 @@ CODECOMMIT_BRANCH = "master" REPO_DIR = "/tmp/repo_dir" - -DESCRIBE_MODEL_PACKAGE_RESPONSE = { - "InferenceSpecification": { - "SupportedResponseMIMETypes": ["text"], - "SupportedContentTypes": ["text/csv"], - "SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], - "Containers": [ - { - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample:latest", - "ImageDigest": "sha256:1234556789", - "ModelDataUrl": "s3://bucket/output/model.tar.gz", - } - ], - "SupportedRealtimeInferenceInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], - }, - "ModelPackageDescription": "Model Package created from training with " - "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", - "CreationTime": 1542752036.687, - "ModelPackageArn": "arn:aws:sagemaker:us-east-2:123:model-package/mp-scikit-decision-trees", - "ModelPackageStatusDetails": {"ValidationStatuses": [], "ImageScanStatuses": []}, - "SourceAlgorithmSpecification": { - "SourceAlgorithms": [ - { - "ModelDataUrl": "s3://bucket/output/model.tar.gz", - "AlgorithmName": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", - } - ] - }, - "ModelPackageStatus": "Completed", - "ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502", - "CertifyForMarketplace": False, -} - DESCRIBE_COMPILATION_JOB_RESPONSE = { "CompilationJobStatus": "Completed", "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"}, @@ -419,10 +385,6 @@ def test_model_enable_network_isolation(sagemaker_session): @patch("sagemaker.model.Model._create_sagemaker_model", Mock()) def test_model_create_transformer(sagemaker_session): - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE - ) - tags = [{"Key": "k", "Value": "v"}] model = DummyFrameworkModel(sagemaker_session=sagemaker_session) instance_type = "ml.m4.xlarge" @@ -453,82 +415,6 @@ def test_transformer_creates_correct_session(local_session, session): assert transformer.sagemaker_session == session.return_value -def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE - ) - - model_package = ModelPackage( - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session - ) - assert model_package.enable_network_isolation() is False - - -def test_model_package_enable_network_isolation_with_product_id(sagemaker_session): - model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) - model_package_response["InferenceSpecification"]["Containers"].append( - { - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", - "ModelDataUrl": "s3://bucket/output/model.tar.gz", - "ProductId": "some-product-id", - } - ) - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=model_package_response - ) - - model_package = ModelPackage( - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session - ) - assert model_package.enable_network_isolation() is True - - -@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) -def test_model_package_create_transformer(sagemaker_session): - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE - ) - - model_package = ModelPackage( - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session - ) - model_package.name = "auto-generated-model" - transformer = model_package.transformer( - instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} - ) - assert isinstance(transformer, sagemaker.transformer.Transformer) - assert transformer.model_name == "auto-generated-model" - assert transformer.instance_type == "ml.m4.xlarge" - assert transformer.env == {"test": True} - - -@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) -def test_model_package_create_transformer_with_product_id(sagemaker_session): - model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) - model_package_response["InferenceSpecification"]["Containers"].append( - { - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", - "ModelDataUrl": "s3://bucket/output/model.tar.gz", - "ProductId": "some-product-id", - } - ) - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=model_package_response - ) - - model_package = ModelPackage( - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session - ) - model_package.name = "auto-generated-model" - transformer = model_package.transformer( - instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} - ) - assert isinstance(transformer, sagemaker.transformer.Transformer) - assert transformer.model_name == "auto-generated-model" - assert transformer.instance_type == "ml.m4.xlarge" - assert transformer.env is None - - @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) @patch("time.strftime", MagicMock(return_value=TIMESTAMP)) def test_model_delete_model(sagemaker_session, tmpdir): diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py new file mode 100644 index 0000000000..b74671144f --- /dev/null +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -0,0 +1,128 @@ +# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import copy + +from mock import Mock, patch + +import sagemaker +from sagemaker.model import ModelPackage + +DESCRIBE_MODEL_PACKAGE_RESPONSE = { + "InferenceSpecification": { + "SupportedResponseMIMETypes": ["text"], + "SupportedContentTypes": ["text/csv"], + "SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], + "Containers": [ + { + "Image": "1.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample:latest", + "ImageDigest": "sha256:1234556789", + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + } + ], + "SupportedRealtimeInferenceInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], + }, + "ModelPackageDescription": "Model Package created from training with " + "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + "CreationTime": 1542752036.687, + "ModelPackageArn": "arn:aws:sagemaker:us-east-2:123:model-package/mp-scikit-decision-trees", + "ModelPackageStatusDetails": {"ValidationStatuses": [], "ImageScanStatuses": []}, + "SourceAlgorithmSpecification": { + "SourceAlgorithms": [ + { + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + "AlgorithmName": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + } + ] + }, + "ModelPackageStatus": "Completed", + "ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502", + "CertifyForMarketplace": False, +} + + +def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE + ) + + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) + assert model_package.enable_network_isolation() is False + + +def test_model_package_enable_network_isolation_with_product_id(sagemaker_session): + model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) + model_package_response["InferenceSpecification"]["Containers"].append( + { + "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + "ProductId": "some-product-id", + } + ) + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=model_package_response + ) + + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) + assert model_package.enable_network_isolation() is True + + +@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) +def test_model_package_create_transformer(sagemaker_session): + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE + ) + + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) + model_package.name = "auto-generated-model" + transformer = model_package.transformer( + instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} + ) + assert isinstance(transformer, sagemaker.transformer.Transformer) + assert transformer.model_name == "auto-generated-model" + assert transformer.instance_type == "ml.m4.xlarge" + assert transformer.env == {"test": True} + + +@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) +def test_model_package_create_transformer_with_product_id(sagemaker_session): + model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) + model_package_response["InferenceSpecification"]["Containers"].append( + { + "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + "ProductId": "some-product-id", + } + ) + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=model_package_response + ) + + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) + model_package.name = "auto-generated-model" + transformer = model_package.transformer( + instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} + ) + assert isinstance(transformer, sagemaker.transformer.Transformer) + assert transformer.model_name == "auto-generated-model" + assert transformer.instance_type == "ml.m4.xlarge" + assert transformer.env is None From 27189a6e1069067f269334cce08e769476347f3e Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Wed, 15 Apr 2020 13:55:44 -0700 Subject: [PATCH 2/4] split Model and FrameworkModel tests into different files --- .../sagemaker/model/test_framework_model.py | 868 ++++++++++++++++++ tests/unit/sagemaker/model/test_model.py | 854 +---------------- 2 files changed, 871 insertions(+), 851 deletions(-) create mode 100644 tests/unit/sagemaker/model/test_framework_model.py diff --git a/tests/unit/sagemaker/model/test_framework_model.py b/tests/unit/sagemaker/model/test_framework_model.py new file mode 100644 index 0000000000..635a59ae79 --- /dev/null +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -0,0 +1,868 @@ +# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import os +import subprocess + +from sagemaker.model import FrameworkModel +from sagemaker.predictor import RealTimePredictor + +import pytest +from mock import MagicMock, Mock, patch + +MODEL_DATA = "s3://bucket/model.tar.gz" +MODEL_IMAGE = "mi" +ENTRY_POINT = "blah.py" +INSTANCE_TYPE = "p2.xlarge" +ROLE = "some-role" + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +SCRIPT_NAME = "dummy_script.py" +SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_NAME) +TIMESTAMP = "2017-10-10-14-14-15" +BUCKET_NAME = "mybucket" +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "c4.4xlarge" +ACCELERATOR_TYPE = "ml.eia.medium" +IMAGE_NAME = "fakeimage" +REGION = "us-west-2" +NEO_REGION_ACCOUNT = "301217895009" +MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP) +GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" +BRANCH = "test-branch-git-config" +COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" +PRIVATE_GIT_REPO_SSH = "git@github.com:testAccount/private-repo.git" +PRIVATE_GIT_REPO = "https://github.com/testAccount/private-repo.git" +PRIVATE_BRANCH = "test-branch" +PRIVATE_COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" +CODECOMMIT_REPO = "https://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/" +CODECOMMIT_REPO_SSH = "ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/" +CODECOMMIT_BRANCH = "master" +REPO_DIR = "/tmp/repo_dir" + +DESCRIBE_COMPILATION_JOB_RESPONSE = { + "CompilationJobStatus": "Completed", + "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"}, +} + + +class DummyFrameworkModel(FrameworkModel): + def __init__(self, sagemaker_session, **kwargs): + super(DummyFrameworkModel, self).__init__( + MODEL_DATA, + MODEL_IMAGE, + ROLE, + ENTRY_POINT, + sagemaker_session=sagemaker_session, + **kwargs + ) + + def create_predictor(self, endpoint_name): + return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session) + + +class DummyFrameworkModelForGit(FrameworkModel): + def __init__(self, sagemaker_session, entry_point, **kwargs): + super(DummyFrameworkModelForGit, self).__init__( + MODEL_DATA, + MODEL_IMAGE, + ROLE, + entry_point=entry_point, + sagemaker_session=sagemaker_session, + **kwargs + ) + + def create_predictor(self, endpoint_name): + return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session) + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name=REGION) + sms = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + s3_client=None, + s3_resource=None, + ) + sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + return sms + + +@patch("shutil.rmtree", MagicMock()) +@patch("tarfile.open", MagicMock()) +@patch("os.listdir", MagicMock(return_value=["blah.py"])) +@patch("time.strftime", return_value=TIMESTAMP) +def test_prepare_container_def(time, sagemaker_session): + model = DummyFrameworkModel(sagemaker_session) + assert model.prepare_container_def(INSTANCE_TYPE) == { + "Environment": { + "SAGEMAKER_PROGRAM": ENTRY_POINT, + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": REGION, + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + }, + "Image": MODEL_IMAGE, + "ModelDataUrl": MODEL_DATA, + } + + +@patch("shutil.rmtree", MagicMock()) +@patch("tarfile.open", MagicMock()) +@patch("os.listdir", MagicMock(return_value=["blah.py"])) +@patch("time.strftime", return_value=TIMESTAMP) +def test_prepare_container_def_with_network_isolation(time, sagemaker_session): + model = DummyFrameworkModel(sagemaker_session, enable_network_isolation=True) + assert model.prepare_container_def(INSTANCE_TYPE) == { + "Environment": { + "SAGEMAKER_PROGRAM": ENTRY_POINT, + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": REGION, + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", + }, + "Image": MODEL_IMAGE, + "ModelDataUrl": MODEL_DATA, + } + + +@patch("shutil.rmtree", MagicMock()) +@patch("tarfile.open", MagicMock()) +@patch("os.path.exists", MagicMock(return_value=True)) +@patch("os.path.isdir", MagicMock(return_value=True)) +@patch("os.listdir", MagicMock(return_value=["blah.py"])) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) +def test_create_no_defaults(sagemaker_session, tmpdir): + model = DummyFrameworkModel( + sagemaker_session, + source_dir="sd", + env={"a": "a"}, + name="name", + enable_cloudwatch_metrics=True, + container_log_level=55, + code_location="s3://cb/cp", + ) + + assert model.prepare_container_def(INSTANCE_TYPE) == { + "Environment": { + "SAGEMAKER_PROGRAM": ENTRY_POINT, + "SAGEMAKER_SUBMIT_DIRECTORY": "s3://cb/cp/name/sourcedir.tar.gz", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "55", + "SAGEMAKER_REGION": REGION, + "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "true", + "a": "a", + }, + "Image": MODEL_IMAGE, + "ModelDataUrl": MODEL_DATA, + } + + +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) +def test_deploy(sagemaker_session, tmpdir): + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name=MODEL_NAME, + production_variants=[ + { + "InitialVariantWeight": 1, + "ModelName": MODEL_NAME, + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + } + ], + tags=None, + kms_key=None, + wait=True, + data_capture_config_dict=None, + ) + + +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) +def test_deploy_endpoint_name(sagemaker_session, tmpdir): + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + model.deploy(endpoint_name="blah", instance_type=INSTANCE_TYPE, initial_instance_count=55) + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name="blah", + production_variants=[ + { + "InitialVariantWeight": 1, + "ModelName": MODEL_NAME, + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 55, + "VariantName": "AllTraffic", + } + ], + tags=None, + kms_key=None, + wait=True, + data_capture_config_dict=None, + ) + + +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) +def test_deploy_tags(sagemaker_session, tmpdir): + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + tags = [{"ModelName": "TestModel"}] + model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, tags=tags) + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name=MODEL_NAME, + production_variants=[ + { + "InitialVariantWeight": 1, + "ModelName": MODEL_NAME, + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + } + ], + tags=tags, + kms_key=None, + wait=True, + data_capture_config_dict=None, + ) + + +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) +def test_deploy_accelerator_type(tfo, time, sagemaker_session): + model = DummyFrameworkModel(sagemaker_session) + model.deploy( + instance_type=INSTANCE_TYPE, initial_instance_count=1, accelerator_type=ACCELERATOR_TYPE + ) + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name=MODEL_NAME, + production_variants=[ + { + "InitialVariantWeight": 1, + "ModelName": MODEL_NAME, + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + "AcceleratorType": ACCELERATOR_TYPE, + } + ], + tags=None, + kms_key=None, + wait=True, + data_capture_config_dict=None, + ) + + +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("tarfile.open") +@patch("time.strftime", return_value=TIMESTAMP) +def test_deploy_kms_key(tfo, time, sagemaker_session): + key = "some-key-arn" + model = DummyFrameworkModel(sagemaker_session) + model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, kms_key=key) + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name=MODEL_NAME, + production_variants=[ + { + "InitialVariantWeight": 1, + "ModelName": MODEL_NAME, + "InstanceType": INSTANCE_TYPE, + "InitialInstanceCount": 1, + "VariantName": "AllTraffic", + } + ], + tags=None, + kms_key=key, + wait=True, + data_capture_config_dict=None, + ) + + +@patch("sagemaker.session.Session") +@patch("sagemaker.local.LocalSession") +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +def test_deploy_creates_correct_session(local_session, session, tmpdir): + # We expect a LocalSession when deploying to instance_type = 'local' + model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir)) + model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1) + assert model.sagemaker_session == local_session.return_value + + # We expect a real Session when deploying to instance_type != local/local_gpu + model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir)) + model.deploy( + endpoint_name="remote_endpoint", instance_type="ml.m4.4xlarge", initial_instance_count=2 + ) + assert model.sagemaker_session == session.return_value + + +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +def test_deploy_update_endpoint(sagemaker_session, tmpdir): + model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir) + model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, update_endpoint=True) + sagemaker_session.create_endpoint_config.assert_called_with( + name=model.name, + model_name=model.name, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + ) + config_name = sagemaker_session.create_endpoint_config( + name=model.name, + model_name=model.name, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + ) + sagemaker_session.update_endpoint.assert_called_with(model.name, config_name, wait=True) + sagemaker_session.create_endpoint.assert_not_called() + + +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +def test_deploy_update_endpoint_optional_args(sagemaker_session, tmpdir): + endpoint_name = "endpoint-name" + tags = [{"Key": "Value"}] + kms_key = "foo" + data_capture_config = MagicMock() + + model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir) + model.deploy( + instance_type=INSTANCE_TYPE, + initial_instance_count=1, + update_endpoint=True, + endpoint_name=endpoint_name, + accelerator_type=ACCELERATOR_TYPE, + tags=tags, + kms_key=kms_key, + wait=False, + data_capture_config=data_capture_config, + ) + sagemaker_session.create_endpoint_config.assert_called_with( + name=model.name, + model_name=model.name, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + tags=tags, + kms_key=kms_key, + data_capture_config_dict=data_capture_config._to_request_dict(), + ) + config_name = sagemaker_session.create_endpoint_config( + name=model.name, + model_name=model.name, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + wait=False, + ) + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name, wait=False) + sagemaker_session.create_endpoint.assert_not_called() + + +def test_model_enable_network_isolation(sagemaker_session): + model = DummyFrameworkModel(sagemaker_session=sagemaker_session) + assert model.enable_network_isolation() is False + + +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) +def test_model_delete_model(sagemaker_session, tmpdir): + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) + model.delete_model() + + sagemaker_session.delete_model.assert_called_with(model.name) + + +def test_delete_non_deployed_model(sagemaker_session): + model = DummyFrameworkModel(sagemaker_session) + with pytest.raises( + ValueError, match="The SageMaker model must be created first before attempting to delete." + ): + model.delete_model() + + +def test_compile_model_for_inferentia(sagemaker_session, tmpdir): + sagemaker_session.wait_for_compilation_job = Mock( + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + model.compile( + target_instance_family="ml_inf", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tensorflow", + framework_version="1.15.0", + job_name="compile-model", + ) + assert ( + "{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3".format( + NEO_REGION_ACCOUNT, REGION + ) + == model.image + ) + assert model._is_compiled_model is True + + +def test_compile_model_for_edge_device(sagemaker_session, tmpdir): + sagemaker_session.wait_for_compilation_job = Mock( + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + model.compile( + target_instance_family="deeplens", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tensorflow", + job_name="compile-model", + ) + assert model._is_compiled_model is False + + +def test_compile_model_for_edge_device_tflite(sagemaker_session, tmpdir): + sagemaker_session.wait_for_compilation_job = Mock( + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + model.compile( + target_instance_family="deeplens", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tflite", + job_name="tflite-compile-model", + ) + assert model._is_compiled_model is False + + +def test_compile_model_for_cloud(sagemaker_session, tmpdir): + sagemaker_session.wait_for_compilation_job = Mock( + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + model.compile( + target_instance_family="ml_c4", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tensorflow", + job_name="compile-model", + ) + assert model._is_compiled_model is True + + +def test_compile_model_for_cloud_tflite(sagemaker_session, tmpdir): + sagemaker_session.wait_for_compilation_job = Mock( + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + model.compile( + target_instance_family="ml_c4", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tflite", + job_name="tflite-compile-model", + ) + assert model._is_compiled_model is True + + +@patch("sagemaker.session.Session") +@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +def test_compile_creates_session(session): + session.return_value.boto_region_name = "us-west-2" + + model = DummyFrameworkModel(sagemaker_session=None) + model.compile( + target_instance_family="ml_c4", + input_shape={"data": [1, 3, 1024, 1024]}, + output_path="s3://output", + role="role", + framework="tensorflow", + job_name="compile-model", + ) + + assert model.sagemaker_session == session.return_value + + +def test_check_neo_region(sagemaker_session, tmpdir): + sagemaker_session.wait_for_compilation_job = Mock( + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE + ) + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + ec2_region_list = [ + "us-east-2", + "us-east-1", + "us-west-1", + "us-west-2", + "ap-east-1", + "ap-south-1", + "ap-northeast-3", + "ap-northeast-2", + "ap-southeast-1", + "ap-southeast-2", + "ap-northeast-1", + "ca-central-1", + "cn-north-1", + "cn-northwest-1", + "eu-central-1", + "eu-west-1", + "eu-west-2", + "eu-west-3", + "eu-north-1", + "sa-east-1", + "us-gov-east-1", + "us-gov-west-1", + ] + neo_support_region = [ + "us-west-1", + "us-west-2", + "us-east-1", + "us-east-2", + "eu-west-1", + "eu-west-2", + "eu-west-3", + "eu-central-1", + "eu-north-1", + "ap-northeast-1", + "ap-northeast-2", + "ap-east-1", + "ap-south-1", + "ap-southeast-1", + "ap-southeast-2", + "sa-east-1", + "ca-central-1", + "me-south-1", + "cn-north-1", + "cn-northwest-1", + "us-gov-west-1", + ] + for region_name in ec2_region_list: + if region_name in neo_support_region: + assert model.check_neo_region(region_name) is True + else: + assert model.check_neo_region(region_name) is False + + +@patch("sagemaker.git_utils.git_clone_repo") +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_succeed(tar_and_upload_dir, git_clone_repo, sagemaker_session): + git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: { + "entry_point": "entry_point", + "source_dir": "/tmp/repo_dir/source_dir", + "dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"], + } + entry_point = "entry_point" + source_dir = "source_dir" + dependencies = ["foo", "bar"] + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, + entry_point=entry_point, + source_dir=source_dir, + dependencies=dependencies, + git_config=git_config, + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies) + assert model.entry_point == "entry_point" + assert model.source_dir == "/tmp/repo_dir/source_dir" + assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"] + + +def test_git_support_repo_not_provided(sagemaker_session): + entry_point = "source_dir/entry_point" + git_config = {"branch": BRANCH, "commit": COMMIT} + with pytest.raises(ValueError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "Please provide a repo for git_config." in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir" + ), +) +def test_git_support_git_clone_fail(sagemaker_session): + entry_point = "source_dir/entry_point" + git_config = {"repo": "https://github.com/aws/no-such-repo.git", "branch": BRANCH} + with pytest.raises(subprocess.CalledProcessError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "returned non-zero exit status" in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git checkout branch-that-does-not-exist" + ), +) +def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): + entry_point = "source_dir/entry_point" + git_config = {"repo": GIT_REPO, "branch": "branch-that-does-not-exist", "commit": COMMIT} + with pytest.raises(subprocess.CalledProcessError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "returned non-zero exit status" in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git checkout commit-sha-that-does-not-exist" + ), +) +def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session): + entry_point = "source_dir/entry_point" + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": "commit-sha-that-does-not-exist"} + with pytest.raises(subprocess.CalledProcessError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "returned non-zero exit status" in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=ValueError("Entry point does not exist in the repo."), +) +def test_git_support_entry_point_not_exist(sagemaker_session): + entry_point = "source_dir/entry_point" + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + with pytest.raises(ValueError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "Entry point does not exist in the repo." in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=ValueError("Source directory does not exist in the repo."), +) +def test_git_support_source_dir_not_exist(sagemaker_session): + entry_point = "entry_point" + source_dir = "source_dir_that_does_not_exist" + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + with pytest.raises(ValueError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, + entry_point=entry_point, + source_dir=source_dir, + git_config=git_config, + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "Source directory does not exist in the repo." in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=ValueError("Dependency no-such-dir does not exist in the repo."), +) +def test_git_support_dependencies_not_exist(sagemaker_session): + entry_point = "entry_point" + dependencies = ["foo", "no_such_dir"] + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + with pytest.raises(ValueError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, + entry_point=entry_point, + dependencies=dependencies, + git_config=git_config, + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "Dependency", "does not exist in the repo." in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_with_username_password_no_2fa( + tar_and_upload_dir, git_clone_repo, sagemaker_session +): + entry_point = "entry_point" + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "username": "username", + "password": "passw0rd!", + } + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, None, []) + assert model.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_with_token_2fa(tar_and_upload_dir, git_clone_repo, sagemaker_session): + entry_point = "entry_point" + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "token": "my-token", + "2FA_enabled": True, + } + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, None, []) + assert model.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_ssh_no_passphrase_needed( + tar_and_upload_dir, git_clone_repo, sagemaker_session +): + entry_point = "entry_point" + git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, None, []) + assert model.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git clone {} {}".format(PRIVATE_GIT_REPO_SSH, REPO_DIR) + ), +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_ssh_passphrase_required(tar_and_upload_dir, git_clone_repo, sagemaker_session): + entry_point = "entry_point" + git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + with pytest.raises(subprocess.CalledProcessError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "returned non-zero exit status" in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_codecommit_with_username_and_password_succeed( + tar_and_upload_dir, git_clone_repo, sagemaker_session +): + entry_point = "entry_point" + git_config = { + "repo": CODECOMMIT_REPO, + "branch": CODECOMMIT_BRANCH, + "username": "username", + "password": "passw0rd!", + } + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, None, []) + assert model.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_codecommit_ssh_no_passphrase_needed( + tar_and_upload_dir, git_clone_repo, sagemaker_session +): + entry_point = "entry_point" + git_config = {"repo": CODECOMMIT_REPO_SSH, "branch": CODECOMMIT_BRANCH} + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, None, []) + assert model.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git clone {} {}".format(PRIVATE_GIT_REPO_SSH, REPO_DIR) + ), +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_codecommit_ssh_passphrase_required( + tar_and_upload_dir, git_clone_repo, sagemaker_session +): + entry_point = "entry_point" + git_config = {"repo": CODECOMMIT_REPO_SSH, "branch": CODECOMMIT_BRANCH} + with pytest.raises(subprocess.CalledProcessError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "returned non-zero exit status" in str(error) diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index d46f7cacb7..84099cc43b 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -12,375 +12,13 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import os -import subprocess +from mock import Mock, patch import sagemaker -from sagemaker.model import FrameworkModel, Model -from sagemaker.predictor import RealTimePredictor - -import pytest -from mock import MagicMock, Mock, patch +from sagemaker.model import Model MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" -ENTRY_POINT = "blah.py" -INSTANCE_TYPE = "p2.xlarge" -ROLE = "some-role" - -DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") -SCRIPT_NAME = "dummy_script.py" -SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_NAME) -TIMESTAMP = "2017-10-10-14-14-15" -BUCKET_NAME = "mybucket" -INSTANCE_COUNT = 1 -INSTANCE_TYPE = "c4.4xlarge" -ACCELERATOR_TYPE = "ml.eia.medium" -IMAGE_NAME = "fakeimage" -REGION = "us-west-2" -NEO_REGION_ACCOUNT = "301217895009" -MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP) -GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" -BRANCH = "test-branch-git-config" -COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" -PRIVATE_GIT_REPO_SSH = "git@github.com:testAccount/private-repo.git" -PRIVATE_GIT_REPO = "https://github.com/testAccount/private-repo.git" -PRIVATE_BRANCH = "test-branch" -PRIVATE_COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" -CODECOMMIT_REPO = "https://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/" -CODECOMMIT_REPO_SSH = "ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/" -CODECOMMIT_BRANCH = "master" -REPO_DIR = "/tmp/repo_dir" - -DESCRIBE_COMPILATION_JOB_RESPONSE = { - "CompilationJobStatus": "Completed", - "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"}, -} - - -class DummyFrameworkModel(FrameworkModel): - def __init__(self, sagemaker_session, **kwargs): - super(DummyFrameworkModel, self).__init__( - MODEL_DATA, - MODEL_IMAGE, - ROLE, - ENTRY_POINT, - sagemaker_session=sagemaker_session, - **kwargs - ) - - def create_predictor(self, endpoint_name): - return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session) - - -class DummyFrameworkModelForGit(FrameworkModel): - def __init__(self, sagemaker_session, entry_point, **kwargs): - super(DummyFrameworkModelForGit, self).__init__( - MODEL_DATA, - MODEL_IMAGE, - ROLE, - entry_point=entry_point, - sagemaker_session=sagemaker_session, - **kwargs - ) - - def create_predictor(self, endpoint_name): - return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session) - - -@pytest.fixture() -def sagemaker_session(): - boto_mock = Mock(name="boto_session", region_name=REGION) - sms = Mock( - name="sagemaker_session", - boto_session=boto_mock, - boto_region_name=REGION, - config=None, - local_mode=False, - s3_client=None, - s3_resource=None, - ) - sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) - return sms - - -@patch("shutil.rmtree", MagicMock()) -@patch("tarfile.open", MagicMock()) -@patch("os.listdir", MagicMock(return_value=["blah.py"])) -@patch("time.strftime", return_value=TIMESTAMP) -def test_prepare_container_def(time, sagemaker_session): - model = DummyFrameworkModel(sagemaker_session) - assert model.prepare_container_def(INSTANCE_TYPE) == { - "Environment": { - "SAGEMAKER_PROGRAM": ENTRY_POINT, - "SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/mi-2017-10-10-14-14-15/sourcedir.tar.gz", - "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", - "SAGEMAKER_REGION": REGION, - "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", - }, - "Image": MODEL_IMAGE, - "ModelDataUrl": MODEL_DATA, - } - - -@patch("shutil.rmtree", MagicMock()) -@patch("tarfile.open", MagicMock()) -@patch("os.listdir", MagicMock(return_value=["blah.py"])) -@patch("time.strftime", return_value=TIMESTAMP) -def test_prepare_container_def_with_network_isolation(time, sagemaker_session): - model = DummyFrameworkModel(sagemaker_session, enable_network_isolation=True) - assert model.prepare_container_def(INSTANCE_TYPE) == { - "Environment": { - "SAGEMAKER_PROGRAM": ENTRY_POINT, - "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", - "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", - "SAGEMAKER_REGION": REGION, - "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false", - }, - "Image": MODEL_IMAGE, - "ModelDataUrl": MODEL_DATA, - } - - -@patch("shutil.rmtree", MagicMock()) -@patch("tarfile.open", MagicMock()) -@patch("os.path.exists", MagicMock(return_value=True)) -@patch("os.path.isdir", MagicMock(return_value=True)) -@patch("os.listdir", MagicMock(return_value=["blah.py"])) -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) -def test_create_no_defaults(sagemaker_session, tmpdir): - model = DummyFrameworkModel( - sagemaker_session, - source_dir="sd", - env={"a": "a"}, - name="name", - enable_cloudwatch_metrics=True, - container_log_level=55, - code_location="s3://cb/cp", - ) - - assert model.prepare_container_def(INSTANCE_TYPE) == { - "Environment": { - "SAGEMAKER_PROGRAM": ENTRY_POINT, - "SAGEMAKER_SUBMIT_DIRECTORY": "s3://cb/cp/name/sourcedir.tar.gz", - "SAGEMAKER_CONTAINER_LOG_LEVEL": "55", - "SAGEMAKER_REGION": REGION, - "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "true", - "a": "a", - }, - "Image": MODEL_IMAGE, - "ModelDataUrl": MODEL_DATA, - } - - -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) -def test_deploy(sagemaker_session, tmpdir): - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) - sagemaker_session.endpoint_from_production_variants.assert_called_with( - name=MODEL_NAME, - production_variants=[ - { - "InitialVariantWeight": 1, - "ModelName": MODEL_NAME, - "InstanceType": INSTANCE_TYPE, - "InitialInstanceCount": 1, - "VariantName": "AllTraffic", - } - ], - tags=None, - kms_key=None, - wait=True, - data_capture_config_dict=None, - ) - - -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) -def test_deploy_endpoint_name(sagemaker_session, tmpdir): - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.deploy(endpoint_name="blah", instance_type=INSTANCE_TYPE, initial_instance_count=55) - sagemaker_session.endpoint_from_production_variants.assert_called_with( - name="blah", - production_variants=[ - { - "InitialVariantWeight": 1, - "ModelName": MODEL_NAME, - "InstanceType": INSTANCE_TYPE, - "InitialInstanceCount": 55, - "VariantName": "AllTraffic", - } - ], - tags=None, - kms_key=None, - wait=True, - data_capture_config_dict=None, - ) - - -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) -def test_deploy_tags(sagemaker_session, tmpdir): - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - tags = [{"ModelName": "TestModel"}] - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, tags=tags) - sagemaker_session.endpoint_from_production_variants.assert_called_with( - name=MODEL_NAME, - production_variants=[ - { - "InitialVariantWeight": 1, - "ModelName": MODEL_NAME, - "InstanceType": INSTANCE_TYPE, - "InitialInstanceCount": 1, - "VariantName": "AllTraffic", - } - ], - tags=tags, - kms_key=None, - wait=True, - data_capture_config_dict=None, - ) - - -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -@patch("tarfile.open") -@patch("time.strftime", return_value=TIMESTAMP) -def test_deploy_accelerator_type(tfo, time, sagemaker_session): - model = DummyFrameworkModel(sagemaker_session) - model.deploy( - instance_type=INSTANCE_TYPE, initial_instance_count=1, accelerator_type=ACCELERATOR_TYPE - ) - sagemaker_session.endpoint_from_production_variants.assert_called_with( - name=MODEL_NAME, - production_variants=[ - { - "InitialVariantWeight": 1, - "ModelName": MODEL_NAME, - "InstanceType": INSTANCE_TYPE, - "InitialInstanceCount": 1, - "VariantName": "AllTraffic", - "AcceleratorType": ACCELERATOR_TYPE, - } - ], - tags=None, - kms_key=None, - wait=True, - data_capture_config_dict=None, - ) - - -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -@patch("tarfile.open") -@patch("time.strftime", return_value=TIMESTAMP) -def test_deploy_kms_key(tfo, time, sagemaker_session): - key = "some-key-arn" - model = DummyFrameworkModel(sagemaker_session) - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, kms_key=key) - sagemaker_session.endpoint_from_production_variants.assert_called_with( - name=MODEL_NAME, - production_variants=[ - { - "InitialVariantWeight": 1, - "ModelName": MODEL_NAME, - "InstanceType": INSTANCE_TYPE, - "InitialInstanceCount": 1, - "VariantName": "AllTraffic", - } - ], - tags=None, - kms_key=key, - wait=True, - data_capture_config_dict=None, - ) - - -@patch("sagemaker.session.Session") -@patch("sagemaker.local.LocalSession") -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -def test_deploy_creates_correct_session(local_session, session, tmpdir): - # We expect a LocalSession when deploying to instance_type = 'local' - model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir)) - model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1) - assert model.sagemaker_session == local_session.return_value - - # We expect a real Session when deploying to instance_type != local/local_gpu - model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir)) - model.deploy( - endpoint_name="remote_endpoint", instance_type="ml.m4.4xlarge", initial_instance_count=2 - ) - assert model.sagemaker_session == session.return_value - - -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -def test_deploy_update_endpoint(sagemaker_session, tmpdir): - model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir) - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, update_endpoint=True) - sagemaker_session.create_endpoint_config.assert_called_with( - name=model.name, - model_name=model.name, - initial_instance_count=INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=None, - tags=None, - kms_key=None, - data_capture_config_dict=None, - ) - config_name = sagemaker_session.create_endpoint_config( - name=model.name, - model_name=model.name, - initial_instance_count=INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=ACCELERATOR_TYPE, - ) - sagemaker_session.update_endpoint.assert_called_with(model.name, config_name, wait=True) - sagemaker_session.create_endpoint.assert_not_called() - - -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -def test_deploy_update_endpoint_optional_args(sagemaker_session, tmpdir): - endpoint_name = "endpoint-name" - tags = [{"Key": "Value"}] - kms_key = "foo" - data_capture_config = MagicMock() - - model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir) - model.deploy( - instance_type=INSTANCE_TYPE, - initial_instance_count=1, - update_endpoint=True, - endpoint_name=endpoint_name, - accelerator_type=ACCELERATOR_TYPE, - tags=tags, - kms_key=kms_key, - wait=False, - data_capture_config=data_capture_config, - ) - sagemaker_session.create_endpoint_config.assert_called_with( - name=model.name, - model_name=model.name, - initial_instance_count=INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=ACCELERATOR_TYPE, - tags=tags, - kms_key=kms_key, - data_capture_config_dict=data_capture_config._to_request_dict(), - ) - config_name = sagemaker_session.create_endpoint_config( - name=model.name, - model_name=model.name, - initial_instance_count=INSTANCE_COUNT, - instance_type=INSTANCE_TYPE, - accelerator_type=ACCELERATOR_TYPE, - wait=False, - ) - sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name, wait=False) - sagemaker_session.create_endpoint.assert_not_called() - - -def test_model_enable_network_isolation(sagemaker_session): - model = DummyFrameworkModel(sagemaker_session=sagemaker_session) - assert model.enable_network_isolation() is False @patch("sagemaker.model.Model._create_sagemaker_model") @@ -469,7 +107,7 @@ def test_model_create_transformer_network_isolation(create_sagemaker_model, sage @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) +@patch("sagemaker.fw_utils.tar_and_upload_dir", Mock()) def test_transformer_creates_correct_session(local_session, session): model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) transformer = model.transformer(instance_count=1, instance_type="local") @@ -480,489 +118,3 @@ def test_transformer_creates_correct_session(local_session, session): transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge") assert model.sagemaker_session == session.return_value assert transformer.sagemaker_session == session.return_value - - -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) -def test_model_delete_model(sagemaker_session, tmpdir): - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) - model.delete_model() - - sagemaker_session.delete_model.assert_called_with(model.name) - - -def test_delete_non_deployed_model(sagemaker_session): - model = DummyFrameworkModel(sagemaker_session) - with pytest.raises( - ValueError, match="The SageMaker model must be created first before attempting to delete." - ): - model.delete_model() - - -def test_compile_model_for_inferentia(sagemaker_session, tmpdir): - sagemaker_session.wait_for_compilation_job = Mock( - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE - ) - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.compile( - target_instance_family="ml_inf", - input_shape={"data": [1, 3, 1024, 1024]}, - output_path="s3://output", - role="role", - framework="tensorflow", - framework_version="1.15.0", - job_name="compile-model", - ) - assert ( - "{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3".format( - NEO_REGION_ACCOUNT, REGION - ) - == model.image - ) - assert model._is_compiled_model is True - - -def test_compile_model_for_edge_device(sagemaker_session, tmpdir): - sagemaker_session.wait_for_compilation_job = Mock( - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE - ) - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.compile( - target_instance_family="deeplens", - input_shape={"data": [1, 3, 1024, 1024]}, - output_path="s3://output", - role="role", - framework="tensorflow", - job_name="compile-model", - ) - assert model._is_compiled_model is False - - -def test_compile_model_for_edge_device_tflite(sagemaker_session, tmpdir): - sagemaker_session.wait_for_compilation_job = Mock( - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE - ) - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.compile( - target_instance_family="deeplens", - input_shape={"data": [1, 3, 1024, 1024]}, - output_path="s3://output", - role="role", - framework="tflite", - job_name="tflite-compile-model", - ) - assert model._is_compiled_model is False - - -def test_compile_model_for_cloud(sagemaker_session, tmpdir): - sagemaker_session.wait_for_compilation_job = Mock( - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE - ) - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.compile( - target_instance_family="ml_c4", - input_shape={"data": [1, 3, 1024, 1024]}, - output_path="s3://output", - role="role", - framework="tensorflow", - job_name="compile-model", - ) - assert model._is_compiled_model is True - - -def test_compile_model_for_cloud_tflite(sagemaker_session, tmpdir): - sagemaker_session.wait_for_compilation_job = Mock( - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE - ) - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - model.compile( - target_instance_family="ml_c4", - input_shape={"data": [1, 3, 1024, 1024]}, - output_path="s3://output", - role="role", - framework="tflite", - job_name="tflite-compile-model", - ) - assert model._is_compiled_model is True - - -@patch("sagemaker.session.Session") -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -def test_compile_creates_session(session): - session.return_value.boto_region_name = "us-west-2" - - model = DummyFrameworkModel(sagemaker_session=None) - model.compile( - target_instance_family="ml_c4", - input_shape={"data": [1, 3, 1024, 1024]}, - output_path="s3://output", - role="role", - framework="tensorflow", - job_name="compile-model", - ) - - assert model.sagemaker_session == session.return_value - - -def test_check_neo_region(sagemaker_session, tmpdir): - sagemaker_session.wait_for_compilation_job = Mock( - return_value=DESCRIBE_COMPILATION_JOB_RESPONSE - ) - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) - ec2_region_list = [ - "us-east-2", - "us-east-1", - "us-west-1", - "us-west-2", - "ap-east-1", - "ap-south-1", - "ap-northeast-3", - "ap-northeast-2", - "ap-southeast-1", - "ap-southeast-2", - "ap-northeast-1", - "ca-central-1", - "cn-north-1", - "cn-northwest-1", - "eu-central-1", - "eu-west-1", - "eu-west-2", - "eu-west-3", - "eu-north-1", - "sa-east-1", - "us-gov-east-1", - "us-gov-west-1", - ] - neo_support_region = [ - "us-west-1", - "us-west-2", - "us-east-1", - "us-east-2", - "eu-west-1", - "eu-west-2", - "eu-west-3", - "eu-central-1", - "eu-north-1", - "ap-northeast-1", - "ap-northeast-2", - "ap-east-1", - "ap-south-1", - "ap-southeast-1", - "ap-southeast-2", - "sa-east-1", - "ca-central-1", - "me-south-1", - "cn-north-1", - "cn-northwest-1", - "us-gov-west-1", - ] - for region_name in ec2_region_list: - if region_name in neo_support_region: - assert model.check_neo_region(region_name) is True - else: - assert model.check_neo_region(region_name) is False - - -@patch("sagemaker.git_utils.git_clone_repo") -@patch("sagemaker.model.fw_utils.tar_and_upload_dir") -def test_git_support_succeed(tar_and_upload_dir, git_clone_repo, sagemaker_session): - git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: { - "entry_point": "entry_point", - "source_dir": "/tmp/repo_dir/source_dir", - "dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"], - } - entry_point = "entry_point" - source_dir = "source_dir" - dependencies = ["foo", "bar"] - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, - entry_point=entry_point, - source_dir=source_dir, - dependencies=dependencies, - git_config=git_config, - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies) - assert model.entry_point == "entry_point" - assert model.source_dir == "/tmp/repo_dir/source_dir" - assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"] - - -def test_git_support_repo_not_provided(sagemaker_session): - entry_point = "source_dir/entry_point" - git_config = {"branch": BRANCH, "commit": COMMIT} - with pytest.raises(ValueError) as error: - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - assert "Please provide a repo for git_config." in str(error) - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=subprocess.CalledProcessError( - returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir" - ), -) -def test_git_support_git_clone_fail(sagemaker_session): - entry_point = "source_dir/entry_point" - git_config = {"repo": "https://github.com/aws/no-such-repo.git", "branch": BRANCH} - with pytest.raises(subprocess.CalledProcessError) as error: - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - assert "returned non-zero exit status" in str(error) - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=subprocess.CalledProcessError( - returncode=1, cmd="git checkout branch-that-does-not-exist" - ), -) -def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): - entry_point = "source_dir/entry_point" - git_config = {"repo": GIT_REPO, "branch": "branch-that-does-not-exist", "commit": COMMIT} - with pytest.raises(subprocess.CalledProcessError) as error: - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - assert "returned non-zero exit status" in str(error) - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=subprocess.CalledProcessError( - returncode=1, cmd="git checkout commit-sha-that-does-not-exist" - ), -) -def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session): - entry_point = "source_dir/entry_point" - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": "commit-sha-that-does-not-exist"} - with pytest.raises(subprocess.CalledProcessError) as error: - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - assert "returned non-zero exit status" in str(error) - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=ValueError("Entry point does not exist in the repo."), -) -def test_git_support_entry_point_not_exist(sagemaker_session): - entry_point = "source_dir/entry_point" - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} - with pytest.raises(ValueError) as error: - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - assert "Entry point does not exist in the repo." in str(error) - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=ValueError("Source directory does not exist in the repo."), -) -def test_git_support_source_dir_not_exist(sagemaker_session): - entry_point = "entry_point" - source_dir = "source_dir_that_does_not_exist" - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} - with pytest.raises(ValueError) as error: - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, - entry_point=entry_point, - source_dir=source_dir, - git_config=git_config, - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - assert "Source directory does not exist in the repo." in str(error) - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=ValueError("Dependency no-such-dir does not exist in the repo."), -) -def test_git_support_dependencies_not_exist(sagemaker_session): - entry_point = "entry_point" - dependencies = ["foo", "no_such_dir"] - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} - with pytest.raises(ValueError) as error: - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, - entry_point=entry_point, - dependencies=dependencies, - git_config=git_config, - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - assert "Dependency", "does not exist in the repo." in str(error) - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { - "entry_point": "/tmp/repo_dir/entry_point", - "source_dir": None, - "dependencies": None, - }, -) -@patch("sagemaker.model.fw_utils.tar_and_upload_dir") -def test_git_support_with_username_password_no_2fa( - tar_and_upload_dir, git_clone_repo, sagemaker_session -): - entry_point = "entry_point" - git_config = { - "repo": PRIVATE_GIT_REPO, - "branch": PRIVATE_BRANCH, - "commit": PRIVATE_COMMIT, - "username": "username", - "password": "passw0rd!", - } - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - git_clone_repo.assert_called_with(git_config, entry_point, None, []) - assert model.entry_point == "/tmp/repo_dir/entry_point" - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { - "entry_point": "/tmp/repo_dir/entry_point", - "source_dir": None, - "dependencies": None, - }, -) -@patch("sagemaker.model.fw_utils.tar_and_upload_dir") -def test_git_support_with_token_2fa(tar_and_upload_dir, git_clone_repo, sagemaker_session): - entry_point = "entry_point" - git_config = { - "repo": PRIVATE_GIT_REPO, - "branch": PRIVATE_BRANCH, - "commit": PRIVATE_COMMIT, - "token": "my-token", - "2FA_enabled": True, - } - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - git_clone_repo.assert_called_with(git_config, entry_point, None, []) - assert model.entry_point == "/tmp/repo_dir/entry_point" - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { - "entry_point": "/tmp/repo_dir/entry_point", - "source_dir": None, - "dependencies": None, - }, -) -@patch("sagemaker.model.fw_utils.tar_and_upload_dir") -def test_git_support_ssh_no_passphrase_needed( - tar_and_upload_dir, git_clone_repo, sagemaker_session -): - entry_point = "entry_point" - git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - git_clone_repo.assert_called_with(git_config, entry_point, None, []) - assert model.entry_point == "/tmp/repo_dir/entry_point" - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=subprocess.CalledProcessError( - returncode=1, cmd="git clone {} {}".format(PRIVATE_GIT_REPO_SSH, REPO_DIR) - ), -) -@patch("sagemaker.model.fw_utils.tar_and_upload_dir") -def test_git_support_ssh_passphrase_required(tar_and_upload_dir, git_clone_repo, sagemaker_session): - entry_point = "entry_point" - git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} - with pytest.raises(subprocess.CalledProcessError) as error: - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - assert "returned non-zero exit status" in str(error) - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { - "entry_point": "/tmp/repo_dir/entry_point", - "source_dir": None, - "dependencies": None, - }, -) -@patch("sagemaker.model.fw_utils.tar_and_upload_dir") -def test_git_support_codecommit_with_username_and_password_succeed( - tar_and_upload_dir, git_clone_repo, sagemaker_session -): - entry_point = "entry_point" - git_config = { - "repo": CODECOMMIT_REPO, - "branch": CODECOMMIT_BRANCH, - "username": "username", - "password": "passw0rd!", - } - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - git_clone_repo.assert_called_with(git_config, entry_point, None, []) - assert model.entry_point == "/tmp/repo_dir/entry_point" - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { - "entry_point": "/tmp/repo_dir/entry_point", - "source_dir": None, - "dependencies": None, - }, -) -@patch("sagemaker.model.fw_utils.tar_and_upload_dir") -def test_git_support_codecommit_ssh_no_passphrase_needed( - tar_and_upload_dir, git_clone_repo, sagemaker_session -): - entry_point = "entry_point" - git_config = {"repo": CODECOMMIT_REPO_SSH, "branch": CODECOMMIT_BRANCH} - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - git_clone_repo.assert_called_with(git_config, entry_point, None, []) - assert model.entry_point == "/tmp/repo_dir/entry_point" - - -@patch( - "sagemaker.git_utils.git_clone_repo", - side_effect=subprocess.CalledProcessError( - returncode=1, cmd="git clone {} {}".format(PRIVATE_GIT_REPO_SSH, REPO_DIR) - ), -) -@patch("sagemaker.model.fw_utils.tar_and_upload_dir") -def test_git_support_codecommit_ssh_passphrase_required( - tar_and_upload_dir, git_clone_repo, sagemaker_session -): - entry_point = "entry_point" - git_config = {"repo": CODECOMMIT_REPO_SSH, "branch": CODECOMMIT_BRANCH} - with pytest.raises(subprocess.CalledProcessError) as error: - model = DummyFrameworkModelForGit( - sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config - ) - model.prepare_container_def(instance_type=INSTANCE_TYPE) - assert "returned non-zero exit status" in str(error) From e1ec4a1b3fe97bf4eb6e588541cf40b63c684022 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Wed, 15 Apr 2020 14:15:06 -0700 Subject: [PATCH 3/4] fix sagemaker_session fixture --- tests/unit/sagemaker/model/test_model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 84099cc43b..ca0a8c4f60 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import pytest from mock import Mock, patch import sagemaker @@ -21,6 +22,11 @@ MODEL_IMAGE = "mi" +@pytest.fixture +def sagemaker_session(): + return Mock() + + @patch("sagemaker.model.Model._create_sagemaker_model") def test_model_create_transformer(create_sagemaker_model, sagemaker_session): model_name = "auto-generated-model" From a4ac44d197ecc2b2f0db5bf23bc270152b64c88b Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Wed, 15 Apr 2020 14:31:16 -0700 Subject: [PATCH 4/4] fix sagemaker_session fixture --- tests/unit/sagemaker/model/test_model_package.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index b74671144f..fd9dfc1471 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -14,6 +14,7 @@ import copy +import pytest from mock import Mock, patch import sagemaker @@ -52,6 +53,11 @@ } +@pytest.fixture +def sagemaker_session(): + return Mock() + + def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): sagemaker_session.sagemaker_client.describe_model_package = Mock( return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE