Skip to content

Commit 167b723

Browse files
authored
feature: script mode for model class (#2841)
1 parent 00f23e6 commit 167b723

File tree

9 files changed

+351
-125
lines changed

9 files changed

+351
-125
lines changed

src/sagemaker/chainer/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
168168
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
169169
self._upload_code(deploy_key_prefix)
170170
deploy_env = dict(self.env)
171-
deploy_env.update(self._framework_env_vars())
171+
deploy_env.update(self._script_mode_env_vars())
172172

173173
if self.model_server_workers:
174174
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/huggingface/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
273273
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
274274
self._upload_code(deploy_key_prefix, repack=True)
275275
deploy_env = dict(self.env)
276-
deploy_env.update(self._framework_env_vars())
276+
deploy_env.update(self._script_mode_env_vars())
277277

278278
if self.model_server_workers:
279279
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/model.py

Lines changed: 239 additions & 115 deletions
Large diffs are not rendered by default.

src/sagemaker/mxnet/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
244244
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
245245
self._upload_code(deploy_key_prefix, self._is_mms_version())
246246
deploy_env = dict(self.env)
247-
deploy_env.update(self._framework_env_vars())
247+
deploy_env.update(self._script_mode_env_vars())
248248

249249
if self.model_server_workers:
250250
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/pytorch/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
241241
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
242242
self._upload_code(deploy_key_prefix, repack=self._is_mms_version())
243243
deploy_env = dict(self.env)
244-
deploy_env.update(self._framework_env_vars())
244+
deploy_env.update(self._script_mode_env_vars())
245245

246246
if self.model_server_workers:
247247
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/sklearn/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
165165
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
166166
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
167167
deploy_env = dict(self.env)
168-
deploy_env.update(self._framework_env_vars())
168+
deploy_env.update(self._script_mode_env_vars())
169169

170170
if self.model_server_workers:
171171
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/workflow/airflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
549549
]
550550

551551
deploy_env = dict(model.env)
552-
deploy_env.update(model._framework_env_vars())
552+
deploy_env.update(model._script_mode_env_vars())
553553

554554
try:
555555
if model.model_server_workers:

src/sagemaker/xgboost/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
147147
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
148148
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
149149
deploy_env = dict(self.env)
150-
deploy_env.update(self._framework_env_vars())
150+
deploy_env.update(self._script_mode_env_vars())
151151

152152
if self.model_server_workers:
153153
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

tests/unit/sagemaker/model/test_model.py

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
14+
from unittest.mock import MagicMock
1415

1516
import pytest
1617
from mock import Mock, patch
1718

1819
import sagemaker
19-
from sagemaker.model import Model
20+
from sagemaker.model import FrameworkModel, Model
2021

2122
MODEL_DATA = "s3://bucket/model.tar.gz"
2223
MODEL_IMAGE = "mi"
@@ -27,10 +28,39 @@
2728
INSTANCE_TYPE = "ml.c4.4xlarge"
2829
ROLE = "some-role"
2930

31+
REGION = "us-west-2"
32+
BUCKET_NAME = "some-bucket-name"
33+
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
34+
BRANCH = "test-branch-git-config"
35+
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
36+
ENTRY_POINT_INFERENCE = "inference.py"
3037

31-
@pytest.fixture
38+
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
39+
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
40+
41+
42+
class DummyFrameworkModel(FrameworkModel):
43+
def __init__(self, **kwargs):
44+
super(DummyFrameworkModel, self).__init__(
45+
**kwargs,
46+
)
47+
48+
49+
@pytest.fixture()
3250
def sagemaker_session():
33-
return Mock()
51+
boto_mock = Mock(name="boto_session", region_name=REGION)
52+
sms = MagicMock(
53+
name="sagemaker_session",
54+
boto_session=boto_mock,
55+
boto_region_name=REGION,
56+
config=None,
57+
local_mode=False,
58+
s3_client=None,
59+
s3_resource=None,
60+
)
61+
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
62+
63+
return sms
3464

3565

3666
def test_prepare_container_def_with_model_data():
@@ -345,3 +375,75 @@ def test_delete_model_no_name(sagemaker_session):
345375
):
346376
model.delete_model()
347377
sagemaker_session.delete_model.assert_not_called()
378+
379+
380+
@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
381+
@patch("sagemaker.utils.repack_model")
382+
def test_script_mode_model_same_calls_as_framework(repack_model, sagemaker_session):
383+
t = Model(
384+
entry_point=ENTRY_POINT_INFERENCE,
385+
role=ROLE,
386+
sagemaker_session=sagemaker_session,
387+
source_dir=SCRIPT_URI,
388+
image_uri=IMAGE_URI,
389+
model_data=MODEL_DATA,
390+
)
391+
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
392+
393+
assert len(sagemaker_session.create_model.call_args_list) == 1
394+
assert len(sagemaker_session.endpoint_from_production_variants.call_args_list) == 1
395+
assert len(repack_model.call_args_list) == 1
396+
397+
generic_model_create_model_args = sagemaker_session.create_model.call_args_list
398+
generic_model_endpoint_from_production_variants_args = (
399+
sagemaker_session.endpoint_from_production_variants.call_args_list
400+
)
401+
generic_model_repack_model_args = repack_model.call_args_list
402+
403+
sagemaker_session.create_model.reset_mock()
404+
sagemaker_session.endpoint_from_production_variants.reset_mock()
405+
repack_model.reset_mock()
406+
407+
t = DummyFrameworkModel(
408+
entry_point=ENTRY_POINT_INFERENCE,
409+
role=ROLE,
410+
sagemaker_session=sagemaker_session,
411+
source_dir=SCRIPT_URI,
412+
image_uri=IMAGE_URI,
413+
model_data=MODEL_DATA,
414+
)
415+
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
416+
417+
assert generic_model_create_model_args == sagemaker_session.create_model.call_args_list
418+
assert (
419+
generic_model_endpoint_from_production_variants_args
420+
== sagemaker_session.endpoint_from_production_variants.call_args_list
421+
)
422+
assert generic_model_repack_model_args == repack_model.call_args_list
423+
424+
425+
@patch("sagemaker.git_utils.git_clone_repo")
426+
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
427+
def test_git_support_succeed_model_class(tar_and_upload_dir, git_clone_repo, sagemaker_session):
428+
git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: {
429+
"entry_point": "entry_point",
430+
"source_dir": "/tmp/repo_dir/source_dir",
431+
"dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"],
432+
}
433+
entry_point = "entry_point"
434+
source_dir = "source_dir"
435+
dependencies = ["foo", "bar"]
436+
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
437+
model = Model(
438+
sagemaker_session=sagemaker_session,
439+
entry_point=entry_point,
440+
source_dir=source_dir,
441+
dependencies=dependencies,
442+
git_config=git_config,
443+
image_uri=IMAGE_URI,
444+
)
445+
model.prepare_container_def(instance_type=INSTANCE_TYPE)
446+
git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies)
447+
assert model.entry_point == "entry_point"
448+
assert model.source_dir == "/tmp/repo_dir/source_dir"
449+
assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"]

0 commit comments

Comments
 (0)