Skip to content

Commit bb2b8d4

Browse files
committed
change: move unit tests
1 parent 03615ae commit bb2b8d4

File tree

2 files changed

+105
-132
lines changed

2 files changed

+105
-132
lines changed

tests/unit/sagemaker/model/test_model.py

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
from unittest.mock import MagicMock
1415

1516
import pytest
1617
from mock import Mock, patch
1718

1819
import sagemaker
19-
from sagemaker.model import Model
20+
from sagemaker.model import FrameworkModel, Model
2021

2122
MODEL_DATA = "s3://bucket/model.tar.gz"
2223
MODEL_IMAGE = "mi"
@@ -27,10 +28,39 @@
2728
INSTANCE_TYPE = "ml.c4.4xlarge"
2829
ROLE = "some-role"
2930

31+
REGION = "us-west-2"
32+
BUCKET_NAME = "some-bucket-name"
33+
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
34+
BRANCH = "test-branch-git-config"
35+
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
36+
ENTRY_POINT_INFERENCE = "inference.py"
3037

31-
@pytest.fixture
38+
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
39+
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
40+
41+
42+
class DummyFrameworkModel(FrameworkModel):
43+
def __init__(self, **kwargs):
44+
super(DummyFrameworkModel, self).__init__(
45+
**kwargs,
46+
)
47+
48+
49+
@pytest.fixture()
3250
def sagemaker_session():
33-
return Mock()
51+
boto_mock = Mock(name="boto_session", region_name=REGION)
52+
sms = MagicMock(
53+
name="sagemaker_session",
54+
boto_session=boto_mock,
55+
boto_region_name=REGION,
56+
config=None,
57+
local_mode=False,
58+
s3_client=None,
59+
s3_resource=None,
60+
)
61+
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
62+
63+
return sms
3464

3565

3666
def test_prepare_container_def_with_model_data():
@@ -345,3 +375,75 @@ def test_delete_model_no_name(sagemaker_session):
345375
):
346376
model.delete_model()
347377
sagemaker_session.delete_model.assert_not_called()
378+
379+
380+
@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
381+
@patch("sagemaker.utils.repack_model")
382+
def test_script_mode_model_same_calls_as_framework(repack_model, sagemaker_session):
383+
t = Model(
384+
entry_point=ENTRY_POINT_INFERENCE,
385+
role=ROLE,
386+
sagemaker_session=sagemaker_session,
387+
source_dir=SCRIPT_URI,
388+
image_uri=IMAGE_URI,
389+
model_data=MODEL_DATA,
390+
)
391+
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
392+
393+
assert len(sagemaker_session.create_model.call_args_list) == 1
394+
assert len(sagemaker_session.endpoint_from_production_variants.call_args_list) == 1
395+
assert len(repack_model.call_args_list) == 1
396+
397+
generic_model_create_model_args = sagemaker_session.create_model.call_args_list
398+
generic_model_endpoint_from_production_variants_args = (
399+
sagemaker_session.endpoint_from_production_variants.call_args_list
400+
)
401+
generic_model_repack_model_args = repack_model.call_args_list
402+
403+
sagemaker_session.create_model.reset_mock()
404+
sagemaker_session.endpoint_from_production_variants.reset_mock()
405+
repack_model.reset_mock()
406+
407+
t = DummyFrameworkModel(
408+
entry_point=ENTRY_POINT_INFERENCE,
409+
role=ROLE,
410+
sagemaker_session=sagemaker_session,
411+
source_dir=SCRIPT_URI,
412+
image_uri=IMAGE_URI,
413+
model_data=MODEL_DATA,
414+
)
415+
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
416+
417+
assert generic_model_create_model_args == sagemaker_session.create_model.call_args_list
418+
assert (
419+
generic_model_endpoint_from_production_variants_args
420+
== sagemaker_session.endpoint_from_production_variants.call_args_list
421+
)
422+
assert generic_model_repack_model_args == repack_model.call_args_list
423+
424+
425+
@patch("sagemaker.git_utils.git_clone_repo")
426+
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
427+
def test_git_support_succeed_model_class(tar_and_upload_dir, git_clone_repo, sagemaker_session):
428+
git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: {
429+
"entry_point": "entry_point",
430+
"source_dir": "/tmp/repo_dir/source_dir",
431+
"dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"],
432+
}
433+
entry_point = "entry_point"
434+
source_dir = "source_dir"
435+
dependencies = ["foo", "bar"]
436+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
437+
model = Model(
438+
sagemaker_session=sagemaker_session,
439+
entry_point=entry_point,
440+
source_dir=source_dir,
441+
dependencies=dependencies,
442+
git_config=git_config,
443+
image_uri=IMAGE_URI,
444+
)
445+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
446+
git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies)
447+
assert model.entry_point == "entry_point"
448+
assert model.source_dir == "/tmp/repo_dir/source_dir"
449+
assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"]

tests/unit/test_model.py

Lines changed: 0 additions & 129 deletions
This file was deleted.

0 commit comments

Comments
 (0)