Skip to content

Commit 613ba6a

Browse files
committed
feat: script mode for model class
1 parent b09793a commit 613ba6a

File tree

9 files changed

+338
-121
lines changed

9 files changed

+338
-121
lines changed

src/sagemaker/chainer/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
168168
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
169169
self._upload_code(deploy_key_prefix)
170170
deploy_env = dict(self.env)
171-
deploy_env.update(self._framework_env_vars())
171+
deploy_env.update(self._script_mode_env_vars())
172172

173173
if self.model_server_workers:
174174
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/huggingface/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
273273
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
274274
self._upload_code(deploy_key_prefix, repack=True)
275275
deploy_env = dict(self.env)
276-
deploy_env.update(self._framework_env_vars())
276+
deploy_env.update(self._script_mode_env_vars())
277277

278278
if self.model_server_workers:
279279
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/model.py

Lines changed: 232 additions & 114 deletions
Large diffs are not rendered by default.

src/sagemaker/mxnet/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
244244
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
245245
self._upload_code(deploy_key_prefix, self._is_mms_version())
246246
deploy_env = dict(self.env)
247-
deploy_env.update(self._framework_env_vars())
247+
deploy_env.update(self._script_mode_env_vars())
248248

249249
if self.model_server_workers:
250250
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/pytorch/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
241241
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
242242
self._upload_code(deploy_key_prefix, repack=self._is_mms_version())
243243
deploy_env = dict(self.env)
244-
deploy_env.update(self._framework_env_vars())
244+
deploy_env.update(self._script_mode_env_vars())
245245

246246
if self.model_server_workers:
247247
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/sklearn/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
165165
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
166166
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
167167
deploy_env = dict(self.env)
168-
deploy_env.update(self._framework_env_vars())
168+
deploy_env.update(self._script_mode_env_vars())
169169

170170
if self.model_server_workers:
171171
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

src/sagemaker/workflow/airflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
549549
]
550550

551551
deploy_env = dict(model.env)
552-
deploy_env.update(model._framework_env_vars())
552+
deploy_env.update(model._script_mode_env_vars())
553553

554554
try:
555555
if model.model_server_workers:

src/sagemaker/xgboost/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
147147
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
148148
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
149149
deploy_env = dict(self.env)
150-
deploy_env.update(self._framework_env_vars())
150+
deploy_env.update(self._script_mode_env_vars())
151151

152152
if self.model_server_workers:
153153
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)

tests/unit/test_model.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
16+
import pytest
17+
from mock import MagicMock, Mock, patch
18+
from sagemaker.model import FrameworkModel, Model
19+
20+
21+
ENTRY_POINT_INFERENCE = "inference.py"
22+
REGION = "us-west-2"
23+
TIMESTAMP = "2017-11-06-14:14:15.671"
24+
BUCKET_NAME = "mybucket"
25+
INSTANCE_COUNT = 1
26+
INSTANCE_TYPE = "ml.p2.xlarge"
27+
ROLE = "DummyRole"
28+
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
29+
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
30+
MODEL_DATA = "s3://someprefix2/models/model.tar.gz"
31+
32+
33+
class DummyFrameworkModel(FrameworkModel):
34+
def __init__(self, **kwargs):
35+
super(DummyFrameworkModel, self).__init__(
36+
**kwargs,
37+
)
38+
39+
40+
@pytest.fixture()
41+
def sagemaker_session():
42+
boto_mock = Mock(name="boto_session", region_name=REGION)
43+
sms = MagicMock(
44+
name="sagemaker_session",
45+
boto_session=boto_mock,
46+
boto_region_name=REGION,
47+
config=None,
48+
local_mode=False,
49+
s3_client=None,
50+
s3_resource=None,
51+
)
52+
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
53+
54+
return sms
55+
56+
57+
@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
58+
@patch("sagemaker.utils.repack_model")
59+
def test_script_mode_model_same_calls_as_framework(repack_model, sagemaker_session):
60+
t = Model(
61+
entry_point=ENTRY_POINT_INFERENCE,
62+
role=ROLE,
63+
sagemaker_session=sagemaker_session,
64+
source_dir=SCRIPT_URI,
65+
image_uri=IMAGE_URI,
66+
model_data=MODEL_DATA,
67+
)
68+
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
69+
70+
assert len(sagemaker_session.create_model.call_args_list) == 1
71+
assert len(sagemaker_session.endpoint_from_production_variants.call_args_list) == 1
72+
assert len(repack_model.call_args_list) == 1
73+
74+
generic_model_create_model_args = sagemaker_session.create_model.call_args_list
75+
generic_model_endpoint_from_production_variants_args = (
76+
sagemaker_session.endpoint_from_production_variants.call_args_list
77+
)
78+
generic_model_repack_model_args = repack_model.call_args_list
79+
80+
sagemaker_session.create_model.reset_mock()
81+
sagemaker_session.endpoint_from_production_variants.reset_mock()
82+
repack_model.reset_mock()
83+
84+
t = DummyFrameworkModel(
85+
entry_point=ENTRY_POINT_INFERENCE,
86+
role=ROLE,
87+
sagemaker_session=sagemaker_session,
88+
source_dir=SCRIPT_URI,
89+
image_uri=IMAGE_URI,
90+
model_data=MODEL_DATA,
91+
)
92+
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
93+
94+
assert generic_model_create_model_args == sagemaker_session.create_model.call_args_list
95+
assert (
96+
generic_model_endpoint_from_production_variants_args
97+
== sagemaker_session.endpoint_from_production_variants.call_args_list
98+
)
99+
assert generic_model_repack_model_args == repack_model.call_args_list

0 commit comments

Comments
 (0)