Skip to content

feature: script mode for model class #2841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix, repack=True)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
354 changes: 239 additions & 115 deletions src/sagemaker/model.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix, self._is_mms_version())
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix, repack=self._is_mms_version())
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/workflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
]

deploy_env = dict(model.env)
deploy_env.update(model._framework_env_vars())
deploy_env.update(model._script_mode_env_vars())

try:
if model.model_server_workers:
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(key_prefix=deploy_key_prefix, repack=self.enable_network_isolation())
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
108 changes: 105 additions & 3 deletions tests/unit/sagemaker/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
from unittest.mock import MagicMock

import pytest
from mock import Mock, patch

import sagemaker
from sagemaker.model import Model
from sagemaker.model import FrameworkModel, Model

MODEL_DATA = "s3://bucket/model.tar.gz"
MODEL_IMAGE = "mi"
Expand All @@ -27,10 +28,39 @@
INSTANCE_TYPE = "ml.c4.4xlarge"
ROLE = "some-role"

REGION = "us-west-2"
BUCKET_NAME = "some-bucket-name"
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
BRANCH = "test-branch-git-config"
COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
ENTRY_POINT_INFERENCE = "inference.py"

@pytest.fixture
SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"


class DummyFrameworkModel(FrameworkModel):
def __init__(self, **kwargs):
super(DummyFrameworkModel, self).__init__(
**kwargs,
)


@pytest.fixture()
def sagemaker_session():
return Mock()
boto_mock = Mock(name="boto_session", region_name=REGION)
sms = MagicMock(
name="sagemaker_session",
boto_session=boto_mock,
boto_region_name=REGION,
config=None,
local_mode=False,
s3_client=None,
s3_resource=None,
)
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)

return sms


def test_prepare_container_def_with_model_data():
Expand Down Expand Up @@ -345,3 +375,75 @@ def test_delete_model_no_name(sagemaker_session):
):
model.delete_model()
sagemaker_session.delete_model.assert_not_called()


@patch("time.strftime", MagicMock(return_value=TIMESTAMP))
@patch("sagemaker.utils.repack_model")
def test_script_mode_model_same_calls_as_framework(repack_model, sagemaker_session):
t = Model(
entry_point=ENTRY_POINT_INFERENCE,
role=ROLE,
sagemaker_session=sagemaker_session,
source_dir=SCRIPT_URI,
image_uri=IMAGE_URI,
model_data=MODEL_DATA,
)
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

assert len(sagemaker_session.create_model.call_args_list) == 1
assert len(sagemaker_session.endpoint_from_production_variants.call_args_list) == 1
assert len(repack_model.call_args_list) == 1

generic_model_create_model_args = sagemaker_session.create_model.call_args_list
generic_model_endpoint_from_production_variants_args = (
sagemaker_session.endpoint_from_production_variants.call_args_list
)
generic_model_repack_model_args = repack_model.call_args_list

sagemaker_session.create_model.reset_mock()
sagemaker_session.endpoint_from_production_variants.reset_mock()
repack_model.reset_mock()

t = DummyFrameworkModel(
entry_point=ENTRY_POINT_INFERENCE,
role=ROLE,
sagemaker_session=sagemaker_session,
source_dir=SCRIPT_URI,
image_uri=IMAGE_URI,
model_data=MODEL_DATA,
)
t.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

assert generic_model_create_model_args == sagemaker_session.create_model.call_args_list
assert (
generic_model_endpoint_from_production_variants_args
== sagemaker_session.endpoint_from_production_variants.call_args_list
)
assert generic_model_repack_model_args == repack_model.call_args_list


@patch("sagemaker.git_utils.git_clone_repo")
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
def test_git_support_succeed_model_class(tar_and_upload_dir, git_clone_repo, sagemaker_session):
git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: {
"entry_point": "entry_point",
"source_dir": "/tmp/repo_dir/source_dir",
"dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"],
}
entry_point = "entry_point"
source_dir = "source_dir"
dependencies = ["foo", "bar"]
git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}
model = Model(
sagemaker_session=sagemaker_session,
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
git_config=git_config,
image_uri=IMAGE_URI,
)
model.prepare_container_def(instance_type=INSTANCE_TYPE)
git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies)
assert model.entry_point == "entry_point"
assert model.source_dir == "/tmp/repo_dir/source_dir"
assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"]