Skip to content

breaking: create new inference resources during model.deploy() and model.transformer() #1666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def __init__(
self.predictor_cls = predictor_cls
self.env = env or {}
self.name = name
self._base_name = None
self.vpc_config = vpc_config
self.sagemaker_session = sagemaker_session
self._model_name = None
self.endpoint_name = None
self._is_compiled_model = False
self._enable_network_isolation = enable_network_isolation
Expand Down Expand Up @@ -184,7 +184,10 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
"""
container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)
self.name = self.name or utils.name_from_image(container_def["Image"])

self._ensure_base_name_if_needed(container_def["Image"])
self._set_model_name_if_needed()

enable_network_isolation = self.enable_network_isolation()

self._init_sagemaker_session_if_does_not_exist(instance_type)
Expand All @@ -197,6 +200,16 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag
tags=tags,
)

def _ensure_base_name_if_needed(self, image):
"""Create a base name from the image if there is no model name provided."""
if self.name is None:
self._base_name = self._base_name or utils.base_name_from_image(image)

def _set_model_name_if_needed(self):
"""Generate a new model name if ``self._base_name`` is present."""
if self._base_name:
self.name = utils.name_from_base(self._base_name)

def _framework(self):
"""Placeholder docstring"""
return getattr(self, "__framework_name__", None)
Expand Down Expand Up @@ -471,10 +484,9 @@ def deploy(

compiled_model_suffix = "-".join(instance_type.split(".")[:-1])
if self._is_compiled_model:
name_prefix = self.name or utils.name_from_image(
self.image, max_length=(62 - len(compiled_model_suffix))
)
self.name = "{}-{}".format(name_prefix, compiled_model_suffix)
self._ensure_base_name_if_needed(self.image)
if self._base_name is not None:
self._base_name = "-".join((self._base_name, compiled_model_suffix))

self._create_sagemaker_model(instance_type, accelerator_type, tags)
production_variant = sagemaker.production_variant(
Expand All @@ -483,9 +495,10 @@ def deploy(
if endpoint_name:
self.endpoint_name = endpoint_name
else:
self.endpoint_name = self.name
if self._is_compiled_model and not self.endpoint_name.endswith(compiled_model_suffix):
self.endpoint_name += compiled_model_suffix
base_endpoint_name = self._base_name or utils.base_from_name(self.name)
if self._is_compiled_model and not base_endpoint_name.endswith(compiled_model_suffix):
base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix))
self.endpoint_name = utils.name_from_base(base_endpoint_name)

data_capture_config_dict = None
if data_capture_config is not None:
Expand Down Expand Up @@ -568,7 +581,7 @@ def transformer(
max_payload=max_payload,
env=env,
tags=tags,
base_transform_job_name=self.name,
base_transform_job_name=self._base_name or self.name,
volume_kms_key=volume_kms_key,
sagemaker_session=self.sagemaker_session,
)
Expand Down Expand Up @@ -994,13 +1007,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
if self.env != {}:
container_def["Environment"] = self.env

model_package_short_name = model_package_name.split("/")[-1]
enable_network_isolation = self.enable_network_isolation()
self.name = self.name or utils.name_from_base(model_package_short_name)
self._ensure_base_name_if_needed(model_package_name.split("/")[-1])
self._set_model_name_if_needed()

self.sagemaker_session.create_model(
self.name,
self.role,
container_def,
vpc_config=self.vpc_config,
enable_network_isolation=enable_network_isolation,
enable_network_isolation=self.enable_network_isolation(),
)

def _ensure_base_name_if_needed(self, base_name):
"""Set the base name if there is no model name provided."""
if self.name is None:
self._base_name = base_name
1 change: 0 additions & 1 deletion src/sagemaker/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def __init__(
self.name = name
self.vpc_config = vpc_config
self.sagemaker_session = sagemaker_session
self._model_name = None
self.endpoint_name = None

def pipeline_container_def(self, instance_type):
Expand Down
88 changes: 72 additions & 16 deletions tests/unit/sagemaker/model/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

MODEL_DATA = "s3://bucket/model.tar.gz"
MODEL_IMAGE = "mi"
TIMESTAMP = "2017-10-10-14-14-15"
TIMESTAMP = "2020-07-02-20-10-30-288"
MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP)
ENDPOINT_NAME = "endpoint-{}".format(TIMESTAMP)

ACCELERATOR_TYPE = "ml.eia.medium"
INSTANCE_COUNT = 2
Expand All @@ -46,9 +47,8 @@ def sagemaker_session():

@patch("sagemaker.production_variant")
@patch("sagemaker.model.Model.prepare_container_def")
@patch("sagemaker.utils.name_from_image")
def test_deploy(name_from_image, prepare_container_def, production_variant, sagemaker_session):
name_from_image.return_value = MODEL_NAME
@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME)
def test_deploy(name_from_base, prepare_container_def, production_variant, sagemaker_session):
production_variant.return_value = BASE_PRODUCTION_VARIANT

container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
Expand All @@ -57,7 +57,9 @@ def test_deploy(name_from_image, prepare_container_def, production_variant, sage
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

name_from_image.assert_called_with(MODEL_IMAGE)
name_from_base.assert_called_with(MODEL_IMAGE)
assert 2 == name_from_base.call_count

prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None)
production_variant.assert_called_with(
MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=None
Expand All @@ -77,9 +79,12 @@ def test_deploy(name_from_image, prepare_container_def, production_variant, sage
)


@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
@patch("sagemaker.model.Model._create_sagemaker_model")
@patch("sagemaker.production_variant")
def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sagemaker_session):
def test_deploy_accelerator_type(
production_variant, create_sagemaker_model, name_from_base, sagemaker_session
):
model = Model(
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
)
Expand All @@ -100,7 +105,7 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag
)

sagemaker_session.endpoint_from_production_variants.assert_called_with(
name=MODEL_NAME,
name=ENDPOINT_NAME,
production_variants=[production_variant_result],
tags=None,
kms_key=None,
Expand All @@ -109,7 +114,6 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag
)


@patch("sagemaker.utils.name_from_image", Mock())
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
def test_deploy_endpoint_name(sagemaker_session):
Expand All @@ -122,6 +126,7 @@ def test_deploy_endpoint_name(sagemaker_session):
initial_instance_count=INSTANCE_COUNT,
)

assert endpoint_name == model.endpoint_name
sagemaker_session.endpoint_from_production_variants.assert_called_with(
name=endpoint_name,
production_variants=[BASE_PRODUCTION_VARIANT],
Expand All @@ -132,9 +137,57 @@ def test_deploy_endpoint_name(sagemaker_session):
)


@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
@patch("sagemaker.utils.name_from_base")
@patch("sagemaker.utils.base_from_name")
@patch("sagemaker.production_variant")
def test_deploy_generates_endpoint_name_each_time_from_model_name(
production_variant, base_from_name, name_from_base, sagemaker_session
):
model = Model(
MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session
)

model.deploy(
instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT,
)
model.deploy(
instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT,
)

base_from_name.assert_called_with(MODEL_NAME)
name_from_base.assert_called_with(base_from_name.return_value)
assert 2 == name_from_base.call_count


@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
@patch("sagemaker.utils.name_from_base")
@patch("sagemaker.utils.base_from_name")
@patch("sagemaker.production_variant")
def test_deploy_generates_endpoint_name_each_time_from_base_name(
production_variant, base_from_name, name_from_base, sagemaker_session
):
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)

base_name = "foo"
model._base_name = base_name

model.deploy(
instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT,
)
model.deploy(
instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT,
)

base_from_name.assert_not_called()
name_from_base.assert_called_with(base_name)
assert 2 == name_from_base.call_count


@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
@patch("sagemaker.model.Model._create_sagemaker_model")
def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_session):
def test_deploy_tags(create_sagemaker_model, production_variant, name_from_base, sagemaker_session):
model = Model(
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
)
Expand All @@ -144,7 +197,7 @@ def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_sessi

create_sagemaker_model.assert_called_with(INSTANCE_TYPE, None, tags)
sagemaker_session.endpoint_from_production_variants.assert_called_with(
name=MODEL_NAME,
name=ENDPOINT_NAME,
production_variants=[BASE_PRODUCTION_VARIANT],
tags=tags,
kms_key=None,
Expand All @@ -154,8 +207,9 @@ def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_sessi


@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
def test_deploy_kms_key(production_variant, sagemaker_session):
def test_deploy_kms_key(production_variant, name_from_base, sagemaker_session):
model = Model(
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
)
Expand All @@ -164,7 +218,7 @@ def test_deploy_kms_key(production_variant, sagemaker_session):
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, kms_key=key)

sagemaker_session.endpoint_from_production_variants.assert_called_with(
name=MODEL_NAME,
name=ENDPOINT_NAME,
production_variants=[BASE_PRODUCTION_VARIANT],
tags=None,
kms_key=key,
Expand All @@ -174,16 +228,17 @@ def test_deploy_kms_key(production_variant, sagemaker_session):


@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
def test_deploy_async(production_variant, sagemaker_session):
def test_deploy_async(production_variant, name_from_base, sagemaker_session):
model = Model(
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
)

model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, wait=False)

sagemaker_session.endpoint_from_production_variants.assert_called_with(
name=MODEL_NAME,
name=ENDPOINT_NAME,
production_variants=[BASE_PRODUCTION_VARIANT],
tags=None,
kms_key=None,
Expand All @@ -193,8 +248,9 @@ def test_deploy_async(production_variant, sagemaker_session):


@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
def test_deploy_data_capture_config(production_variant, sagemaker_session):
def test_deploy_data_capture_config(production_variant, name_from_base, sagemaker_session):
model = Model(
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
)
Expand All @@ -210,7 +266,7 @@ def test_deploy_data_capture_config(production_variant, sagemaker_session):

data_capture_config._to_request_dict.assert_called_with()
sagemaker_session.endpoint_from_production_variants.assert_called_with(
name=MODEL_NAME,
name=ENDPOINT_NAME,
production_variants=[BASE_PRODUCTION_VARIANT],
tags=None,
kms_key=None,
Expand Down
Loading