Skip to content

Commit 36987c4

Browse files
authored
infra: split model unit tests by Model, FrameworkModel, and ModelPackage (#1417)
1 parent ff0b9f4 commit 36987c4

File tree

3 files changed

+261
-211
lines changed

3 files changed

+261
-211
lines changed

tests/unit/test_model.py renamed to tests/unit/sagemaker/model/test_framework_model.py

Lines changed: 1 addition & 211 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,10 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import copy
1615
import os
1716
import subprocess
1817

19-
import sagemaker
20-
from sagemaker.model import FrameworkModel, Model, ModelPackage
18+
from sagemaker.model import FrameworkModel
2119
from sagemaker.predictor import RealTimePredictor
2220

2321
import pytest
@@ -53,39 +51,6 @@
5351
CODECOMMIT_BRANCH = "master"
5452
REPO_DIR = "/tmp/repo_dir"
5553

56-
57-
DESCRIBE_MODEL_PACKAGE_RESPONSE = {
58-
"InferenceSpecification": {
59-
"SupportedResponseMIMETypes": ["text"],
60-
"SupportedContentTypes": ["text/csv"],
61-
"SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"],
62-
"Containers": [
63-
{
64-
"Image": "1.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample:latest",
65-
"ImageDigest": "sha256:1234556789",
66-
"ModelDataUrl": "s3://bucket/output/model.tar.gz",
67-
}
68-
],
69-
"SupportedRealtimeInferenceInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"],
70-
},
71-
"ModelPackageDescription": "Model Package created from training with "
72-
"arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
73-
"CreationTime": 1542752036.687,
74-
"ModelPackageArn": "arn:aws:sagemaker:us-east-2:123:model-package/mp-scikit-decision-trees",
75-
"ModelPackageStatusDetails": {"ValidationStatuses": [], "ImageScanStatuses": []},
76-
"SourceAlgorithmSpecification": {
77-
"SourceAlgorithms": [
78-
{
79-
"ModelDataUrl": "s3://bucket/output/model.tar.gz",
80-
"AlgorithmName": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
81-
}
82-
]
83-
},
84-
"ModelPackageStatus": "Completed",
85-
"ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502",
86-
"CertifyForMarketplace": False,
87-
}
88-
8954
DESCRIBE_COMPILATION_JOB_RESPONSE = {
9055
"CompilationJobStatus": "Completed",
9156
"ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"},
@@ -417,181 +382,6 @@ def test_model_enable_network_isolation(sagemaker_session):
417382
assert model.enable_network_isolation() is False
418383

419384

420-
@patch("sagemaker.model.Model._create_sagemaker_model")
421-
def test_model_create_transformer(create_sagemaker_model, sagemaker_session):
422-
model_name = "auto-generated-model"
423-
model = Model(MODEL_DATA, MODEL_IMAGE, name=model_name, sagemaker_session=sagemaker_session)
424-
425-
instance_type = "ml.m4.xlarge"
426-
transformer = model.transformer(instance_count=1, instance_type=instance_type)
427-
428-
create_sagemaker_model.assert_called_with(instance_type, tags=None)
429-
430-
assert isinstance(transformer, sagemaker.transformer.Transformer)
431-
assert transformer.model_name == model_name
432-
assert transformer.instance_type == instance_type
433-
assert transformer.instance_count == 1
434-
assert transformer.sagemaker_session == sagemaker_session
435-
assert transformer.base_transform_job_name == model_name
436-
437-
assert transformer.strategy is None
438-
assert transformer.env is None
439-
assert transformer.output_path is None
440-
assert transformer.output_kms_key is None
441-
assert transformer.accept is None
442-
assert transformer.assemble_with is None
443-
assert transformer.volume_kms_key is None
444-
assert transformer.max_concurrent_transforms is None
445-
assert transformer.max_payload is None
446-
assert transformer.tags is None
447-
448-
449-
@patch("sagemaker.model.Model._create_sagemaker_model")
450-
def test_model_create_transformer_optional_params(create_sagemaker_model, sagemaker_session):
451-
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
452-
453-
instance_type = "ml.m4.xlarge"
454-
strategy = "MultiRecord"
455-
assemble_with = "Line"
456-
output_path = "s3://bucket/path"
457-
kms_key = "key"
458-
accept = "text/csv"
459-
env = {"test": True}
460-
max_concurrent_transforms = 1
461-
max_payload = 6
462-
tags = [{"Key": "k", "Value": "v"}]
463-
464-
transformer = model.transformer(
465-
instance_count=1,
466-
instance_type=instance_type,
467-
strategy=strategy,
468-
assemble_with=assemble_with,
469-
output_path=output_path,
470-
output_kms_key=kms_key,
471-
accept=accept,
472-
env=env,
473-
max_concurrent_transforms=max_concurrent_transforms,
474-
max_payload=max_payload,
475-
tags=tags,
476-
volume_kms_key=kms_key,
477-
)
478-
479-
create_sagemaker_model.assert_called_with(instance_type, tags=tags)
480-
481-
assert isinstance(transformer, sagemaker.transformer.Transformer)
482-
assert transformer.strategy == strategy
483-
assert transformer.assemble_with == assemble_with
484-
assert transformer.output_path == output_path
485-
assert transformer.output_kms_key == kms_key
486-
assert transformer.accept == accept
487-
assert transformer.max_concurrent_transforms == max_concurrent_transforms
488-
assert transformer.max_payload == max_payload
489-
assert transformer.env == env
490-
assert transformer.tags == tags
491-
assert transformer.volume_kms_key == kms_key
492-
493-
494-
@patch("sagemaker.model.Model._create_sagemaker_model")
495-
def test_model_create_transformer_network_isolation(create_sagemaker_model, sagemaker_session):
496-
model = Model(
497-
MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session, enable_network_isolation=True
498-
)
499-
500-
transformer = model.transformer(1, "ml.m4.xlarge", env={"should_be": "overwritten"})
501-
assert transformer.env is None
502-
503-
504-
@patch("sagemaker.session.Session")
505-
@patch("sagemaker.local.LocalSession")
506-
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
507-
def test_transformer_creates_correct_session(local_session, session):
508-
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None)
509-
transformer = model.transformer(instance_count=1, instance_type="local")
510-
assert model.sagemaker_session == local_session.return_value
511-
assert transformer.sagemaker_session == local_session.return_value
512-
513-
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None)
514-
transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge")
515-
assert model.sagemaker_session == session.return_value
516-
assert transformer.sagemaker_session == session.return_value
517-
518-
519-
def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session):
520-
sagemaker_session.sagemaker_client.describe_model_package = Mock(
521-
return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE
522-
)
523-
524-
model_package = ModelPackage(
525-
role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session
526-
)
527-
assert model_package.enable_network_isolation() is False
528-
529-
530-
def test_model_package_enable_network_isolation_with_product_id(sagemaker_session):
531-
model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE)
532-
model_package_response["InferenceSpecification"]["Containers"].append(
533-
{
534-
"Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest",
535-
"ModelDataUrl": "s3://bucket/output/model.tar.gz",
536-
"ProductId": "some-product-id",
537-
}
538-
)
539-
sagemaker_session.sagemaker_client.describe_model_package = Mock(
540-
return_value=model_package_response
541-
)
542-
543-
model_package = ModelPackage(
544-
role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session
545-
)
546-
assert model_package.enable_network_isolation() is True
547-
548-
549-
@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock())
550-
def test_model_package_create_transformer(sagemaker_session):
551-
sagemaker_session.sagemaker_client.describe_model_package = Mock(
552-
return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE
553-
)
554-
555-
model_package = ModelPackage(
556-
role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session
557-
)
558-
model_package.name = "auto-generated-model"
559-
transformer = model_package.transformer(
560-
instance_count=1, instance_type="ml.m4.xlarge", env={"test": True}
561-
)
562-
assert isinstance(transformer, sagemaker.transformer.Transformer)
563-
assert transformer.model_name == "auto-generated-model"
564-
assert transformer.instance_type == "ml.m4.xlarge"
565-
assert transformer.env == {"test": True}
566-
567-
568-
@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock())
569-
def test_model_package_create_transformer_with_product_id(sagemaker_session):
570-
model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE)
571-
model_package_response["InferenceSpecification"]["Containers"].append(
572-
{
573-
"Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest",
574-
"ModelDataUrl": "s3://bucket/output/model.tar.gz",
575-
"ProductId": "some-product-id",
576-
}
577-
)
578-
sagemaker_session.sagemaker_client.describe_model_package = Mock(
579-
return_value=model_package_response
580-
)
581-
582-
model_package = ModelPackage(
583-
role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session
584-
)
585-
model_package.name = "auto-generated-model"
586-
transformer = model_package.transformer(
587-
instance_count=1, instance_type="ml.m4.xlarge", env={"test": True}
588-
)
589-
assert isinstance(transformer, sagemaker.transformer.Transformer)
590-
assert transformer.model_name == "auto-generated-model"
591-
assert transformer.instance_type == "ml.m4.xlarge"
592-
assert transformer.env is None
593-
594-
595385
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
596386
@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
597387
def test_model_delete_model(sagemaker_session, tmpdir):
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from mock import Mock, patch
17+
18+
import sagemaker
19+
from sagemaker.model import Model
20+
21+
MODEL_DATA = "s3://bucket/model.tar.gz"
22+
MODEL_IMAGE = "mi"
23+
24+
25+
@pytest.fixture
26+
def sagemaker_session():
27+
return Mock()
28+
29+
30+
@patch("sagemaker.model.Model._create_sagemaker_model")
31+
def test_model_create_transformer(create_sagemaker_model, sagemaker_session):
32+
model_name = "auto-generated-model"
33+
model = Model(MODEL_DATA, MODEL_IMAGE, name=model_name, sagemaker_session=sagemaker_session)
34+
35+
instance_type = "ml.m4.xlarge"
36+
transformer = model.transformer(instance_count=1, instance_type=instance_type)
37+
38+
create_sagemaker_model.assert_called_with(instance_type, tags=None)
39+
40+
assert isinstance(transformer, sagemaker.transformer.Transformer)
41+
assert transformer.model_name == model_name
42+
assert transformer.instance_type == instance_type
43+
assert transformer.instance_count == 1
44+
assert transformer.sagemaker_session == sagemaker_session
45+
assert transformer.base_transform_job_name == model_name
46+
47+
assert transformer.strategy is None
48+
assert transformer.env is None
49+
assert transformer.output_path is None
50+
assert transformer.output_kms_key is None
51+
assert transformer.accept is None
52+
assert transformer.assemble_with is None
53+
assert transformer.volume_kms_key is None
54+
assert transformer.max_concurrent_transforms is None
55+
assert transformer.max_payload is None
56+
assert transformer.tags is None
57+
58+
59+
@patch("sagemaker.model.Model._create_sagemaker_model")
60+
def test_model_create_transformer_optional_params(create_sagemaker_model, sagemaker_session):
61+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
62+
63+
instance_type = "ml.m4.xlarge"
64+
strategy = "MultiRecord"
65+
assemble_with = "Line"
66+
output_path = "s3://bucket/path"
67+
kms_key = "key"
68+
accept = "text/csv"
69+
env = {"test": True}
70+
max_concurrent_transforms = 1
71+
max_payload = 6
72+
tags = [{"Key": "k", "Value": "v"}]
73+
74+
transformer = model.transformer(
75+
instance_count=1,
76+
instance_type=instance_type,
77+
strategy=strategy,
78+
assemble_with=assemble_with,
79+
output_path=output_path,
80+
output_kms_key=kms_key,
81+
accept=accept,
82+
env=env,
83+
max_concurrent_transforms=max_concurrent_transforms,
84+
max_payload=max_payload,
85+
tags=tags,
86+
volume_kms_key=kms_key,
87+
)
88+
89+
create_sagemaker_model.assert_called_with(instance_type, tags=tags)
90+
91+
assert isinstance(transformer, sagemaker.transformer.Transformer)
92+
assert transformer.strategy == strategy
93+
assert transformer.assemble_with == assemble_with
94+
assert transformer.output_path == output_path
95+
assert transformer.output_kms_key == kms_key
96+
assert transformer.accept == accept
97+
assert transformer.max_concurrent_transforms == max_concurrent_transforms
98+
assert transformer.max_payload == max_payload
99+
assert transformer.env == env
100+
assert transformer.tags == tags
101+
assert transformer.volume_kms_key == kms_key
102+
103+
104+
@patch("sagemaker.model.Model._create_sagemaker_model")
105+
def test_model_create_transformer_network_isolation(create_sagemaker_model, sagemaker_session):
106+
model = Model(
107+
MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session, enable_network_isolation=True
108+
)
109+
110+
transformer = model.transformer(1, "ml.m4.xlarge", env={"should_be": "overwritten"})
111+
assert transformer.env is None
112+
113+
114+
@patch("sagemaker.session.Session")
115+
@patch("sagemaker.local.LocalSession")
116+
@patch("sagemaker.fw_utils.tar_and_upload_dir", Mock())
117+
def test_transformer_creates_correct_session(local_session, session):
118+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None)
119+
transformer = model.transformer(instance_count=1, instance_type="local")
120+
assert model.sagemaker_session == local_session.return_value
121+
assert transformer.sagemaker_session == local_session.return_value
122+
123+
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None)
124+
transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge")
125+
assert model.sagemaker_session == session.return_value
126+
assert transformer.sagemaker_session == session.return_value

0 commit comments

Comments
 (0)