diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index fd26e1ae46..055bbb2716 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -116,9 +116,9 @@ def __init__( self.predictor_cls = predictor_cls self.env = env or {} self.name = name + self._base_name = None self.vpc_config = vpc_config self.sagemaker_session = sagemaker_session - self._model_name = None self.endpoint_name = None self._is_compiled_model = False self._enable_network_isolation = enable_network_isolation @@ -184,7 +184,10 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags """ container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type) - self.name = self.name or utils.name_from_image(container_def["Image"]) + + self._ensure_base_name_if_needed(container_def["Image"]) + self._set_model_name_if_needed() + enable_network_isolation = self.enable_network_isolation() self._init_sagemaker_session_if_does_not_exist(instance_type) @@ -197,6 +200,16 @@ def _create_sagemaker_model(self, instance_type=None, accelerator_type=None, tag tags=tags, ) + def _ensure_base_name_if_needed(self, image): + """Create a base name from the image if there is no model name provided.""" + if self.name is None: + self._base_name = self._base_name or utils.base_name_from_image(image) + + def _set_model_name_if_needed(self): + """Generate a new model name if ``self._base_name`` is present.""" + if self._base_name: + self.name = utils.name_from_base(self._base_name) + def _framework(self): """Placeholder docstring""" return getattr(self, "__framework_name__", None) @@ -471,10 +484,9 @@ def deploy( compiled_model_suffix = "-".join(instance_type.split(".")[:-1]) if self._is_compiled_model: - name_prefix = self.name or utils.name_from_image( - self.image, max_length=(62 - len(compiled_model_suffix)) - ) - self.name = "{}-{}".format(name_prefix, compiled_model_suffix) + self._ensure_base_name_if_needed(self.image) + if self._base_name is not None: + self._base_name = "-".join((self._base_name, compiled_model_suffix)) self._create_sagemaker_model(instance_type, accelerator_type, tags) production_variant = sagemaker.production_variant( @@ -483,9 +495,10 @@ def deploy( if endpoint_name: self.endpoint_name = endpoint_name else: - self.endpoint_name = self.name - if self._is_compiled_model and not self.endpoint_name.endswith(compiled_model_suffix): - self.endpoint_name += compiled_model_suffix + base_endpoint_name = self._base_name or utils.base_from_name(self.name) + if self._is_compiled_model and not base_endpoint_name.endswith(compiled_model_suffix): + base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix)) + self.endpoint_name = utils.name_from_base(base_endpoint_name) data_capture_config_dict = None if data_capture_config is not None: @@ -568,7 +581,7 @@ def transformer( max_payload=max_payload, env=env, tags=tags, - base_transform_job_name=self.name, + base_transform_job_name=self._base_name or self.name, volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session, ) @@ -994,13 +1007,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar if self.env != {}: container_def["Environment"] = self.env - model_package_short_name = model_package_name.split("/")[-1] - enable_network_isolation = self.enable_network_isolation() - self.name = self.name or utils.name_from_base(model_package_short_name) + self._ensure_base_name_if_needed(model_package_name.split("/")[-1]) + self._set_model_name_if_needed() + self.sagemaker_session.create_model( self.name, self.role, container_def, vpc_config=self.vpc_config, - enable_network_isolation=enable_network_isolation, + enable_network_isolation=self.enable_network_isolation(), ) + + def _ensure_base_name_if_needed(self, base_name): + """Set the base name if there is no model name provided.""" + if self.name is None: + self._base_name = base_name diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index d343eadefa..5409eb4067 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -60,7 +60,6 @@ def __init__( self.name = name self.vpc_config = vpc_config self.sagemaker_session = sagemaker_session - self._model_name = None self.endpoint_name = None def pipeline_container_def(self, instance_type): diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 9b017fce28..fb5aa1e750 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -22,8 +22,9 @@ MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" -TIMESTAMP = "2017-10-10-14-14-15" +TIMESTAMP = "2020-07-02-20-10-30-288" MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP) +ENDPOINT_NAME = "endpoint-{}".format(TIMESTAMP) ACCELERATOR_TYPE = "ml.eia.medium" INSTANCE_COUNT = 2 @@ -46,9 +47,8 @@ def sagemaker_session(): @patch("sagemaker.production_variant") @patch("sagemaker.model.Model.prepare_container_def") -@patch("sagemaker.utils.name_from_image") -def test_deploy(name_from_image, prepare_container_def, production_variant, sagemaker_session): - name_from_image.return_value = MODEL_NAME +@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME) +def test_deploy(name_from_base, prepare_container_def, production_variant, sagemaker_session): production_variant.return_value = BASE_PRODUCTION_VARIANT container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA} @@ -57,7 +57,9 @@ def test_deploy(name_from_image, prepare_container_def, production_variant, sage model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session) model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT) - name_from_image.assert_called_with(MODEL_IMAGE) + name_from_base.assert_called_with(MODEL_IMAGE) + assert 2 == name_from_base.call_count + prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None) production_variant.assert_called_with( MODEL_NAME, INSTANCE_TYPE, INSTANCE_COUNT, accelerator_type=None @@ -77,9 +79,12 @@ def test_deploy(name_from_image, prepare_container_def, production_variant, sage ) +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) @patch("sagemaker.model.Model._create_sagemaker_model") @patch("sagemaker.production_variant") -def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sagemaker_session): +def test_deploy_accelerator_type( + production_variant, create_sagemaker_model, name_from_base, sagemaker_session +): model = Model( MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session ) @@ -100,7 +105,7 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag ) sagemaker_session.endpoint_from_production_variants.assert_called_with( - name=MODEL_NAME, + name=ENDPOINT_NAME, production_variants=[production_variant_result], tags=None, kms_key=None, @@ -109,7 +114,6 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag ) -@patch("sagemaker.utils.name_from_image", Mock()) @patch("sagemaker.model.Model._create_sagemaker_model", Mock()) @patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) def test_deploy_endpoint_name(sagemaker_session): @@ -122,6 +126,7 @@ def test_deploy_endpoint_name(sagemaker_session): initial_instance_count=INSTANCE_COUNT, ) + assert endpoint_name == model.endpoint_name sagemaker_session.endpoint_from_production_variants.assert_called_with( name=endpoint_name, production_variants=[BASE_PRODUCTION_VARIANT], @@ -132,9 +137,57 @@ def test_deploy_endpoint_name(sagemaker_session): ) +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.utils.name_from_base") +@patch("sagemaker.utils.base_from_name") +@patch("sagemaker.production_variant") +def test_deploy_generates_endpoint_name_each_time_from_model_name( + production_variant, base_from_name, name_from_base, sagemaker_session +): + model = Model( + MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session + ) + + model.deploy( + instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, + ) + model.deploy( + instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, + ) + + base_from_name.assert_called_with(MODEL_NAME) + name_from_base.assert_called_with(base_from_name.return_value) + assert 2 == name_from_base.call_count + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.utils.name_from_base") +@patch("sagemaker.utils.base_from_name") +@patch("sagemaker.production_variant") +def test_deploy_generates_endpoint_name_each_time_from_base_name( + production_variant, base_from_name, name_from_base, sagemaker_session +): + model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session) + + base_name = "foo" + model._base_name = base_name + + model.deploy( + instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, + ) + model.deploy( + instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, + ) + + base_from_name.assert_not_called() + name_from_base.assert_called_with(base_name) + assert 2 == name_from_base.call_count + + +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) @patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) @patch("sagemaker.model.Model._create_sagemaker_model") -def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_session): +def test_deploy_tags(create_sagemaker_model, production_variant, name_from_base, sagemaker_session): model = Model( MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session ) @@ -144,7 +197,7 @@ def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_sessi create_sagemaker_model.assert_called_with(INSTANCE_TYPE, None, tags) sagemaker_session.endpoint_from_production_variants.assert_called_with( - name=MODEL_NAME, + name=ENDPOINT_NAME, production_variants=[BASE_PRODUCTION_VARIANT], tags=tags, kms_key=None, @@ -154,8 +207,9 @@ def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_sessi @patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) @patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) -def test_deploy_kms_key(production_variant, sagemaker_session): +def test_deploy_kms_key(production_variant, name_from_base, sagemaker_session): model = Model( MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session ) @@ -164,7 +218,7 @@ def test_deploy_kms_key(production_variant, sagemaker_session): model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, kms_key=key) sagemaker_session.endpoint_from_production_variants.assert_called_with( - name=MODEL_NAME, + name=ENDPOINT_NAME, production_variants=[BASE_PRODUCTION_VARIANT], tags=None, kms_key=key, @@ -174,8 +228,9 @@ def test_deploy_kms_key(production_variant, sagemaker_session): @patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) @patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) -def test_deploy_async(production_variant, sagemaker_session): +def test_deploy_async(production_variant, name_from_base, sagemaker_session): model = Model( MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session ) @@ -183,7 +238,7 @@ def test_deploy_async(production_variant, sagemaker_session): model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, wait=False) sagemaker_session.endpoint_from_production_variants.assert_called_with( - name=MODEL_NAME, + name=ENDPOINT_NAME, production_variants=[BASE_PRODUCTION_VARIANT], tags=None, kms_key=None, @@ -193,8 +248,9 @@ def test_deploy_async(production_variant, sagemaker_session): @patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) @patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) -def test_deploy_data_capture_config(production_variant, sagemaker_session): +def test_deploy_data_capture_config(production_variant, name_from_base, sagemaker_session): model = Model( MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session ) @@ -210,7 +266,7 @@ def test_deploy_data_capture_config(production_variant, sagemaker_session): data_capture_config._to_request_dict.assert_called_with() sagemaker_session.endpoint_from_production_variants.assert_called_with( - name=MODEL_NAME, + name=ENDPOINT_NAME, production_variants=[BASE_PRODUCTION_VARIANT], tags=None, kms_key=None, diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 1627437616..284745d3e4 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -63,37 +63,30 @@ def test_model_enable_network_isolation(): @patch("sagemaker.model.Model.prepare_container_def") -@patch("sagemaker.utils.name_from_image") -def test_create_sagemaker_model(name_from_image, prepare_container_def, sagemaker_session): - name_from_image.return_value = MODEL_NAME - +def test_create_sagemaker_model(prepare_container_def, sagemaker_session): container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA} prepare_container_def.return_value = container_def - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session) + model = Model(MODEL_DATA, MODEL_IMAGE, name=MODEL_NAME, sagemaker_session=sagemaker_session) model._create_sagemaker_model() prepare_container_def.assert_called_with(None, accelerator_type=None) - name_from_image.assert_called_with(MODEL_IMAGE) - sagemaker_session.create_model.assert_called_with( MODEL_NAME, None, container_def, vpc_config=None, enable_network_isolation=False, tags=None ) -@patch("sagemaker.utils.name_from_image", Mock()) @patch("sagemaker.model.Model.prepare_container_def") def test_create_sagemaker_model_instance_type(prepare_container_def, sagemaker_session): - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session) + model = Model(MODEL_DATA, MODEL_IMAGE, name=MODEL_NAME, sagemaker_session=sagemaker_session) model._create_sagemaker_model(INSTANCE_TYPE) prepare_container_def.assert_called_with(INSTANCE_TYPE, accelerator_type=None) -@patch("sagemaker.utils.name_from_image", Mock()) @patch("sagemaker.model.Model.prepare_container_def") def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemaker_session): - model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session) + model = Model(MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, sagemaker_session=sagemaker_session) accelerator_type = "ml.eia.medium" model._create_sagemaker_model(INSTANCE_TYPE, accelerator_type=accelerator_type) @@ -102,14 +95,11 @@ def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemake @patch("sagemaker.model.Model.prepare_container_def") -@patch("sagemaker.utils.name_from_image") -def test_create_sagemaker_model_tags(name_from_image, prepare_container_def, sagemaker_session): +def test_create_sagemaker_model_tags(prepare_container_def, sagemaker_session): container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA} prepare_container_def.return_value = container_def - name_from_image.return_value = MODEL_NAME - - model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session) + model = Model(MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, sagemaker_session=sagemaker_session) tags = {"Key": "foo", "Value": "bar"} model._create_sagemaker_model(INSTANCE_TYPE, tags=tags) @@ -120,9 +110,10 @@ def test_create_sagemaker_model_tags(name_from_image, prepare_container_def, sag @patch("sagemaker.model.Model.prepare_container_def") -@patch("sagemaker.utils.name_from_image") +@patch("sagemaker.utils.name_from_base") +@patch("sagemaker.utils.base_name_from_image") def test_create_sagemaker_model_optional_model_params( - name_from_image, prepare_container_def, sagemaker_session + base_name_from_image, name_from_base, prepare_container_def, sagemaker_session ): container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA} prepare_container_def.return_value = container_def @@ -140,7 +131,8 @@ def test_create_sagemaker_model_optional_model_params( ) model._create_sagemaker_model(INSTANCE_TYPE) - name_from_image.assert_not_called() + base_name_from_image.assert_not_called() + name_from_base.assert_not_called() sagemaker_session.create_model.assert_called_with( MODEL_NAME, @@ -152,6 +144,44 @@ def test_create_sagemaker_model_optional_model_params( ) +@patch("sagemaker.model.Model.prepare_container_def") +@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME) +@patch("sagemaker.utils.base_name_from_image") +def test_create_sagemaker_model_generates_model_name( + base_name_from_image, name_from_base, prepare_container_def, sagemaker_session +): + container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA} + prepare_container_def.return_value = container_def + + model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session,) + model._create_sagemaker_model(INSTANCE_TYPE) + + base_name_from_image.assert_called_with(MODEL_IMAGE) + name_from_base.assert_called_with(base_name_from_image.return_value) + + sagemaker_session.create_model.assert_called_with( + MODEL_NAME, None, container_def, vpc_config=None, enable_network_isolation=False, tags=None, + ) + + +@patch("sagemaker.model.Model.prepare_container_def") +@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME) +@patch("sagemaker.utils.base_name_from_image") +def test_create_sagemaker_model_generates_model_name_each_time( + base_name_from_image, name_from_base, prepare_container_def, sagemaker_session +): + container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA} + prepare_container_def.return_value = container_def + + model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session,) + model._create_sagemaker_model(INSTANCE_TYPE) + model._create_sagemaker_model(INSTANCE_TYPE) + + base_name_from_image.assert_called_once_with(MODEL_IMAGE) + name_from_base.assert_called_with(base_name_from_image.return_value) + assert 2 == name_from_base.call_count + + @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_create_sagemaker_model_creates_correct_session(local_session, session): @@ -238,8 +268,8 @@ def test_model_create_transformer_optional_params(create_sagemaker_model, sagema assert transformer.volume_kms_key == kms_key -@patch("sagemaker.model.Model._create_sagemaker_model") -def test_model_create_transformer_network_isolation(create_sagemaker_model, sagemaker_session): +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +def test_model_create_transformer_network_isolation(sagemaker_session): model = Model( MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, enable_network_isolation=True ) @@ -248,6 +278,17 @@ def test_model_create_transformer_network_isolation(create_sagemaker_model, sage assert transformer.env is None +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +def test_model_create_transformer_base_name(sagemaker_session): + model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session) + + base_name = "foo" + model._base_name = base_name + + transformer = model.transformer(1, "ml.m4.xlarge") + assert base_name == transformer.base_transform_job_name + + @patch("sagemaker.session.Session") @patch("sagemaker.local.LocalSession") def test_transformer_creates_correct_session(local_session, session): diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index fd9dfc1471..b9caa92dec 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -55,14 +55,15 @@ @pytest.fixture def sagemaker_session(): - return Mock() - - -def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): - sagemaker_session.sagemaker_client.describe_model_package = Mock( + session = Mock() + session.sagemaker_client.describe_model_package = Mock( return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE ) + return session + + +def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): model_package = ModelPackage( role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session ) @@ -88,12 +89,63 @@ def test_model_package_enable_network_isolation_with_product_id(sagemaker_sessio assert model_package.enable_network_isolation() is True -@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) -def test_model_package_create_transformer(sagemaker_session): - sagemaker_session.sagemaker_client.describe_model_package = Mock( - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE +@patch("sagemaker.utils.name_from_base") +def test_create_sagemaker_model_uses_model_name(name_from_base, sagemaker_session): + model_name = "my-model" + model_package_name = "my-model-package" + + model_package = ModelPackage( + role="role", + name=model_name, + model_package_arn=model_package_name, + sagemaker_session=sagemaker_session, + ) + + model_package._create_sagemaker_model() + + assert model_name == model_package.name + name_from_base.assert_not_called() + + sagemaker_session.create_model.assert_called_with( + model_name, + "role", + {"ModelPackageName": model_package_name}, + vpc_config=None, + enable_network_isolation=False, ) + +@patch("sagemaker.utils.name_from_base") +def test_create_sagemaker_model_generates_model_name(name_from_base, sagemaker_session): + model_package_name = "my-model-package" + + model_package = ModelPackage( + role="role", model_package_arn=model_package_name, sagemaker_session=sagemaker_session + ) + + model_package._create_sagemaker_model() + + name_from_base.assert_called_with(model_package_name) + assert name_from_base.return_value == model_package.name + + +@patch("sagemaker.utils.name_from_base") +def test_create_sagemaker_model_generates_model_name_each_time(name_from_base, sagemaker_session): + model_package_name = "my-model-package" + + model_package = ModelPackage( + role="role", model_package_arn=model_package_name, sagemaker_session=sagemaker_session + ) + + model_package._create_sagemaker_model() + model_package._create_sagemaker_model() + + name_from_base.assert_called_with(model_package_name) + assert 2 == name_from_base.call_count + + +@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) +def test_model_package_create_transformer(sagemaker_session): model_package = ModelPackage( role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session ) diff --git a/tests/unit/sagemaker/model/test_neo.py b/tests/unit/sagemaker/model/test_neo.py index 46d1b0aab3..a7906ef9cf 100644 --- a/tests/unit/sagemaker/model/test_neo.py +++ b/tests/unit/sagemaker/model/test_neo.py @@ -37,7 +37,7 @@ def sagemaker_session(): def _create_model(sagemaker_session=None): - return Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session) + return Model(MODEL_IMAGE, MODEL_DATA, role="role", sagemaker_session=sagemaker_session) def test_compile_model_for_inferentia(sagemaker_session): @@ -212,14 +212,33 @@ def test_check_neo_region(sagemaker_session): assert (region_name in NEO_REGION_LIST) is model.check_neo_region(region_name) -def test_deploy_valid_model_name(sagemaker_session): - model = Model( - image="long-base-name-that-is-over-the-63-character-limit-for-model-names", - model_data=MODEL_DATA, - role="role", - sagemaker_session=sagemaker_session, - ) +def test_deploy_honors_provided_model_name(sagemaker_session): + model = _create_model(sagemaker_session) model._is_compiled_model = True + model_name = "foo" + model.name = model_name + + model.deploy(1, "ml.c4.xlarge") + assert model_name == model.name + + +def test_deploy_add_compiled_model_suffix_to_generated_resource_names(sagemaker_session): + model = _create_model(sagemaker_session) + model._is_compiled_model = True + + model.deploy(1, "ml.c4.xlarge") + assert model.name.startswith("mi-ml-c4") + assert model.endpoint_name.startswith("mi-ml-c4") + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +def test_deploy_add_compiled_model_suffix_to_endpoint_name_from_model_name(sagemaker_session): + model = _create_model(sagemaker_session) + model._is_compiled_model = True + + model_name = "foo" + model.name = model_name + model.deploy(1, "ml.c4.xlarge") - assert len(model.name) <= 63 + assert model.endpoint_name.startswith("{}-ml-c4".format(model_name))