diff --git a/tests/unit/test_model.py b/tests/unit/sagemaker/model/test_framework_model.py similarity index 78% rename from tests/unit/test_model.py rename to tests/unit/sagemaker/model/test_framework_model.py index 0ab00c3c2e..635a59ae79 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/sagemaker/model/test_framework_model.py @@ -12,12 +12,10 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import copy import os import subprocess -import sagemaker -from sagemaker.model import FrameworkModel, Model, ModelPackage +from sagemaker.model import FrameworkModel from sagemaker.predictor import RealTimePredictor import pytest @@ -53,39 +51,6 @@ CODECOMMIT_BRANCH = "master" REPO_DIR = "/tmp/repo_dir" - -DESCRIBE_MODEL_PACKAGE_RESPONSE = { - "InferenceSpecification": { - "SupportedResponseMIMETypes": ["text"], - "SupportedContentTypes": ["text/csv"], - "SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], - "Containers": [ - { - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample:latest", - "ImageDigest": "sha256:1234556789", - "ModelDataUrl": "s3://bucket/output/model.tar.gz", - } - ], - "SupportedRealtimeInferenceInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], - }, - "ModelPackageDescription": "Model Package created from training with " - "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", - "CreationTime": 1542752036.687, - "ModelPackageArn": "arn:aws:sagemaker:us-east-2:123:model-package/mp-scikit-decision-trees", - "ModelPackageStatusDetails": {"ValidationStatuses": [], "ImageScanStatuses": []}, - "SourceAlgorithmSpecification": { - "SourceAlgorithms": [ - { - "ModelDataUrl": "s3://bucket/output/model.tar.gz", - "AlgorithmName": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", - } - ] - }, - "ModelPackageStatus": "Completed", - "ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502", - "CertifyForMarketplace": False, -} - DESCRIBE_COMPILATION_JOB_RESPONSE = { "CompilationJobStatus": "Completed", "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"}, @@ -417,181 +382,6 @@ def test_model_enable_network_isolation(sagemaker_session): assert model.enable_network_isolation() is False -@patch("sagemaker.model.Model._create_sagemaker_model") -def test_model_create_transformer(create_sagemaker_model, sagemaker_session): - model_name = "auto-generated-model" - model = Model(MODEL_DATA, MODEL_IMAGE, name=model_name, sagemaker_session=sagemaker_session) - - instance_type = "ml.m4.xlarge" - transformer = model.transformer(instance_count=1, instance_type=instance_type) - - create_sagemaker_model.assert_called_with(instance_type, tags=None) - - assert isinstance(transformer, sagemaker.transformer.Transformer) - assert transformer.model_name == model_name - assert transformer.instance_type == instance_type - assert transformer.instance_count == 1 - assert transformer.sagemaker_session == sagemaker_session - assert transformer.base_transform_job_name == model_name - - assert transformer.strategy is None - assert transformer.env is None - assert transformer.output_path is None - assert transformer.output_kms_key is None - assert transformer.accept is None - assert transformer.assemble_with is None - assert transformer.volume_kms_key is None - assert transformer.max_concurrent_transforms is None - assert transformer.max_payload is None - assert transformer.tags is None - - -@patch("sagemaker.model.Model._create_sagemaker_model") -def test_model_create_transformer_optional_params(create_sagemaker_model, sagemaker_session): - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session) - - instance_type = "ml.m4.xlarge" - strategy = "MultiRecord" - assemble_with = "Line" - output_path = "s3://bucket/path" - kms_key = "key" - accept = "text/csv" - env = {"test": True} - max_concurrent_transforms = 1 - max_payload = 6 - tags = [{"Key": "k", "Value": "v"}] - - transformer = model.transformer( - instance_count=1, - instance_type=instance_type, - strategy=strategy, - assemble_with=assemble_with, - output_path=output_path, - output_kms_key=kms_key, - accept=accept, - env=env, - max_concurrent_transforms=max_concurrent_transforms, - max_payload=max_payload, - tags=tags, - volume_kms_key=kms_key, - ) - - create_sagemaker_model.assert_called_with(instance_type, tags=tags) - - assert isinstance(transformer, sagemaker.transformer.Transformer) - assert transformer.strategy == strategy - assert transformer.assemble_with == assemble_with - assert transformer.output_path == output_path - assert transformer.output_kms_key == kms_key - assert transformer.accept == accept - assert transformer.max_concurrent_transforms == max_concurrent_transforms - assert transformer.max_payload == max_payload - assert transformer.env == env - assert transformer.tags == tags - assert transformer.volume_kms_key == kms_key - - -@patch("sagemaker.model.Model._create_sagemaker_model") -def test_model_create_transformer_network_isolation(create_sagemaker_model, sagemaker_session): - model = Model( - MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session, enable_network_isolation=True - ) - - transformer = model.transformer(1, "ml.m4.xlarge", env={"should_be": "overwritten"}) - assert transformer.env is None - - -@patch("sagemaker.session.Session") -@patch("sagemaker.local.LocalSession") -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -def test_transformer_creates_correct_session(local_session, session): - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) - transformer = model.transformer(instance_count=1, instance_type="local") - assert model.sagemaker_session == local_session.return_value - assert transformer.sagemaker_session == local_session.return_value - - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) - transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge") - assert model.sagemaker_session == session.return_value - assert transformer.sagemaker_session == session.return_value - - -def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE - ) - - model_package = ModelPackage( - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session - ) - assert model_package.enable_network_isolation() is False - - -def test_model_package_enable_network_isolation_with_product_id(sagemaker_session): - model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) - model_package_response["InferenceSpecification"]["Containers"].append( - { - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", - "ModelDataUrl": "s3://bucket/output/model.tar.gz", - "ProductId": "some-product-id", - } - ) - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=model_package_response - ) - - model_package = ModelPackage( - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session - ) - assert model_package.enable_network_isolation() is True - - -@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) -def test_model_package_create_transformer(sagemaker_session): - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE - ) - - model_package = ModelPackage( - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session - ) - model_package.name = "auto-generated-model" - transformer = model_package.transformer( - instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} - ) - assert isinstance(transformer, sagemaker.transformer.Transformer) - assert transformer.model_name == "auto-generated-model" - assert transformer.instance_type == "ml.m4.xlarge" - assert transformer.env == {"test": True} - - -@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) -def test_model_package_create_transformer_with_product_id(sagemaker_session): - model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) - model_package_response["InferenceSpecification"]["Containers"].append( - { - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", - "ModelDataUrl": "s3://bucket/output/model.tar.gz", - "ProductId": "some-product-id", - } - ) - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=model_package_response - ) - - model_package = ModelPackage( - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session - ) - model_package.name = "auto-generated-model" - transformer = model_package.transformer( - instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} - ) - assert isinstance(transformer, sagemaker.transformer.Transformer) - assert transformer.model_name == "auto-generated-model" - assert transformer.instance_type == "ml.m4.xlarge" - assert transformer.env is None - - @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) @patch("time.strftime", MagicMock(return_value=TIMESTAMP)) def test_model_delete_model(sagemaker_session, tmpdir): diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py new file mode 100644 index 0000000000..ca0a8c4f60 --- /dev/null +++ b/tests/unit/sagemaker/model/test_model.py @@ -0,0 +1,126 @@ +# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from mock import Mock, patch + +import sagemaker +from sagemaker.model import Model + +MODEL_DATA = "s3://bucket/model.tar.gz" +MODEL_IMAGE = "mi" + + +@pytest.fixture +def sagemaker_session(): + return Mock() + + +@patch("sagemaker.model.Model._create_sagemaker_model") +def test_model_create_transformer(create_sagemaker_model, sagemaker_session): + model_name = "auto-generated-model" + model = Model(MODEL_DATA, MODEL_IMAGE, name=model_name, sagemaker_session=sagemaker_session) + + instance_type = "ml.m4.xlarge" + transformer = model.transformer(instance_count=1, instance_type=instance_type) + + create_sagemaker_model.assert_called_with(instance_type, tags=None) + + assert isinstance(transformer, sagemaker.transformer.Transformer) + assert transformer.model_name == model_name + assert transformer.instance_type == instance_type + assert transformer.instance_count == 1 + assert transformer.sagemaker_session == sagemaker_session + assert transformer.base_transform_job_name == model_name + + assert transformer.strategy is None + assert transformer.env is None + assert transformer.output_path is None + assert transformer.output_kms_key is None + assert transformer.accept is None + assert transformer.assemble_with is None + assert transformer.volume_kms_key is None + assert transformer.max_concurrent_transforms is None + assert transformer.max_payload is None + assert transformer.tags is None + + +@patch("sagemaker.model.Model._create_sagemaker_model") +def test_model_create_transformer_optional_params(create_sagemaker_model, sagemaker_session): + model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session) + + instance_type = "ml.m4.xlarge" + strategy = "MultiRecord" + assemble_with = "Line" + output_path = "s3://bucket/path" + kms_key = "key" + accept = "text/csv" + env = {"test": True} + max_concurrent_transforms = 1 + max_payload = 6 + tags = [{"Key": "k", "Value": "v"}] + + transformer = model.transformer( + instance_count=1, + instance_type=instance_type, + strategy=strategy, + assemble_with=assemble_with, + output_path=output_path, + output_kms_key=kms_key, + accept=accept, + env=env, + max_concurrent_transforms=max_concurrent_transforms, + max_payload=max_payload, + tags=tags, + volume_kms_key=kms_key, + ) + + create_sagemaker_model.assert_called_with(instance_type, tags=tags) + + assert isinstance(transformer, sagemaker.transformer.Transformer) + assert transformer.strategy == strategy + assert transformer.assemble_with == assemble_with + assert transformer.output_path == output_path + assert transformer.output_kms_key == kms_key + assert transformer.accept == accept + assert transformer.max_concurrent_transforms == max_concurrent_transforms + assert transformer.max_payload == max_payload + assert transformer.env == env + assert transformer.tags == tags + assert transformer.volume_kms_key == kms_key + + +@patch("sagemaker.model.Model._create_sagemaker_model") +def test_model_create_transformer_network_isolation(create_sagemaker_model, sagemaker_session): + model = Model( + MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session, enable_network_isolation=True + ) + + transformer = model.transformer(1, "ml.m4.xlarge", env={"should_be": "overwritten"}) + assert transformer.env is None + + +@patch("sagemaker.session.Session") +@patch("sagemaker.local.LocalSession") +@patch("sagemaker.fw_utils.tar_and_upload_dir", Mock()) +def test_transformer_creates_correct_session(local_session, session): + model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) + transformer = model.transformer(instance_count=1, instance_type="local") + assert model.sagemaker_session == local_session.return_value + assert transformer.sagemaker_session == local_session.return_value + + model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) + transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge") + assert model.sagemaker_session == session.return_value + assert transformer.sagemaker_session == session.return_value diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py new file mode 100644 index 0000000000..fd9dfc1471 --- /dev/null +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -0,0 +1,134 @@ +# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import copy + +import pytest +from mock import Mock, patch + +import sagemaker +from sagemaker.model import ModelPackage + +DESCRIBE_MODEL_PACKAGE_RESPONSE = { + "InferenceSpecification": { + "SupportedResponseMIMETypes": ["text"], + "SupportedContentTypes": ["text/csv"], + "SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], + "Containers": [ + { + "Image": "1.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample:latest", + "ImageDigest": "sha256:1234556789", + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + } + ], + "SupportedRealtimeInferenceInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], + }, + "ModelPackageDescription": "Model Package created from training with " + "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + "CreationTime": 1542752036.687, + "ModelPackageArn": "arn:aws:sagemaker:us-east-2:123:model-package/mp-scikit-decision-trees", + "ModelPackageStatusDetails": {"ValidationStatuses": [], "ImageScanStatuses": []}, + "SourceAlgorithmSpecification": { + "SourceAlgorithms": [ + { + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + "AlgorithmName": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", + } + ] + }, + "ModelPackageStatus": "Completed", + "ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502", + "CertifyForMarketplace": False, +} + + +@pytest.fixture +def sagemaker_session(): + return Mock() + + +def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE + ) + + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) + assert model_package.enable_network_isolation() is False + + +def test_model_package_enable_network_isolation_with_product_id(sagemaker_session): + model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) + model_package_response["InferenceSpecification"]["Containers"].append( + { + "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + "ProductId": "some-product-id", + } + ) + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=model_package_response + ) + + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) + assert model_package.enable_network_isolation() is True + + +@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) +def test_model_package_create_transformer(sagemaker_session): + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE + ) + + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) + model_package.name = "auto-generated-model" + transformer = model_package.transformer( + instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} + ) + assert isinstance(transformer, sagemaker.transformer.Transformer) + assert transformer.model_name == "auto-generated-model" + assert transformer.instance_type == "ml.m4.xlarge" + assert transformer.env == {"test": True} + + +@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) +def test_model_package_create_transformer_with_product_id(sagemaker_session): + model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) + model_package_response["InferenceSpecification"]["Containers"].append( + { + "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", + "ModelDataUrl": "s3://bucket/output/model.tar.gz", + "ProductId": "some-product-id", + } + ) + sagemaker_session.sagemaker_client.describe_model_package = Mock( + return_value=model_package_response + ) + + model_package = ModelPackage( + role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session + ) + model_package.name = "auto-generated-model" + transformer = model_package.transformer( + instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} + ) + assert isinstance(transformer, sagemaker.transformer.Transformer) + assert transformer.model_name == "auto-generated-model" + assert transformer.instance_type == "ml.m4.xlarge" + assert transformer.env is None