From c1157bd874ae4f00023bba65367310f98521b83c Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 9 Aug 2023 18:21:36 +0000 Subject: [PATCH 1/4] fix: tags for jumpstart model package models --- src/sagemaker/jumpstart/model.py | 37 ++++++++++++++++--- .../sagemaker/jumpstart/model/test_model.py | 7 +++- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index c6b663c0fa..8c5a94fc61 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -310,13 +310,32 @@ def _is_valid_model_id_hook(): super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) - def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-argument + def _create_sagemaker_model( + self, + instance_type=None, + accelerator_type=None, + tags=None, + serverless_inference_config=None, + **kwargs, + ): """Create a SageMaker Model Entity Args: - args: Positional arguments coming from the caller. This class does not require - any so they are ignored. - + instance_type (str): The EC2 instance type that this Model will be + used for, this is only used to determine if the image needs GPU + support or not. + accelerator_type (str): Type of Elastic Inference accelerator to + attach to an endpoint for model loading and inference, for + example, 'ml.eia1.medium'. If not specified, no Elastic + Inference accelerator will be attached to the endpoint. + tags (List[dict[str, str]]): Optional. The list of tags to add to + the model. Example: >>> tags = [{'Key': 'tagname', 'Value': + 'tagvalue'}] For more information about tags, see + https://boto3.amazonaws.com/v1/documentation + /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags + serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): + Specifies configuration related to serverless endpoint. Instance type is + not provided in serverless inference. So this is used to find image URIs. kwargs: Keyword arguments coming from the caller. This class does not require any so they are ignored. """ @@ -347,10 +366,16 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar container_def, vpc_config=self.vpc_config, enable_network_isolation=self.enable_network_isolation(), - tags=kwargs.get("tags"), + tags=tags, ) else: - super(JumpStartModel, self)._create_sagemaker_model(*args, **kwargs) + super(JumpStartModel, self)._create_sagemaker_model( + instance_type=instance_type, + accelerator_type=accelerator_type, + tags=tags, + serverless_inference_config=serverless_inference_config, + **kwargs, + ) def deploy( self, diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index c5197e5399..fb7698741e 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -589,7 +589,10 @@ def test_jumpstart_model_package_arn( model = JumpStartModel(model_id=model_id) - model.deploy() + tag = {"Key": "foo", "Value": "bar"} + tags = [tag] + + model.deploy(tags=tags) self.assertEqual( mock_session.return_value.create_model.call_args[0][2], @@ -599,6 +602,8 @@ def test_jumpstart_model_package_arn( }, ) + self.assertIn(tag, mock_session.return_value.create_model.call_args[1]["tags"]) + @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") From f64ca6dfeb99466844e2d6c50b3d2e12aeeeef10 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 10 Aug 2023 14:09:35 +0000 Subject: [PATCH 2/4] fix: tags for ModelPackage --- src/sagemaker/model.py | 5 ++++- tests/unit/sagemaker/model/test_model_package.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 7ccd7a220d..d0f833795c 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1375,7 +1375,10 @@ def deploy( self._base_name = "-".join((self._base_name, compiled_model_suffix)) self._create_sagemaker_model( - instance_type, accelerator_type, tags, serverless_inference_config + instance_type=instance_type, + accelerator_type=accelerator_type, + tags=tags, + serverless_inference_config=serverless_inference_config, ) serverless_inference_config_dict = ( diff --git a/tests/unit/sagemaker/model/test_model_package.py b/tests/unit/sagemaker/model/test_model_package.py index ca29644603..4e5f4cd3e4 100644 --- a/tests/unit/sagemaker/model/test_model_package.py +++ b/tests/unit/sagemaker/model/test_model_package.py @@ -197,7 +197,7 @@ def test_create_sagemaker_model_include_tags(sagemaker_session): sagemaker_session=sagemaker_session, ) - model_package._create_sagemaker_model(tags=tags) + model_package.deploy(tags=tags, instance_type="ml.p2.xlarge", initial_instance_count=1) sagemaker_session.create_model.assert_called_with( model_name, From 41a21135139f35ec49e81a6d85067e5a367d2ea0 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 10 Aug 2023 22:07:54 +0000 Subject: [PATCH 3/4] chore: improve docstring --- src/sagemaker/jumpstart/model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 8c5a94fc61..71bf0058ca 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -321,21 +321,23 @@ def _create_sagemaker_model( """Create a SageMaker Model Entity Args: - instance_type (str): The EC2 instance type that this Model will be + instance_type (str): Optional. The EC2 instance type that this Model will be used for, this is only used to determine if the image needs GPU - support or not. - accelerator_type (str): Type of Elastic Inference accelerator to + support or not. (Default: None). + accelerator_type (str): Optional. Type of Elastic Inference accelerator to attach to an endpoint for model loading and inference, for example, 'ml.eia1.medium'. If not specified, no Elastic - Inference accelerator will be attached to the endpoint. + Inference accelerator will be attached to the endpoint. (Default: None). tags (List[dict[str, str]]): Optional. The list of tags to add to the model. Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] For more information about tags, see https://boto3.amazonaws.com/v1/documentation /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags + (Default: None). serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): - Specifies configuration related to serverless endpoint. Instance type is + Optional. Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. + (Default: None). kwargs: Keyword arguments coming from the caller. This class does not require any so they are ignored. """ From e4fb3bcc73353e9c6d963e7bbb0bcf2356d0ec32 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Fri, 11 Aug 2023 17:15:10 +0000 Subject: [PATCH 4/4] fix: unit tests --- tests/unit/sagemaker/model/test_deploy.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index a66e5ee5a3..d872ac3f7a 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -159,7 +159,12 @@ def test_deploy_accelerator_type( accelerator_type=ACCELERATOR_TYPE, ) - create_sagemaker_model.assert_called_with(INSTANCE_TYPE, ACCELERATOR_TYPE, None, None) + create_sagemaker_model.assert_called_with( + instance_type=INSTANCE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + tags=None, + serverless_inference_config=None, + ) production_variant.assert_called_with( MODEL_NAME, INSTANCE_TYPE, @@ -271,7 +276,12 @@ def test_deploy_tags(create_sagemaker_model, production_variant, name_from_base, tags = [{"Key": "ModelName", "Value": "TestModel"}] model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, tags=tags) - create_sagemaker_model.assert_called_with(INSTANCE_TYPE, None, tags, None) + create_sagemaker_model.assert_called_with( + instance_type=INSTANCE_TYPE, + accelerator_type=None, + tags=tags, + serverless_inference_config=None, + ) sagemaker_session.endpoint_from_production_variants.assert_called_with( name=ENDPOINT_NAME, production_variants=[BASE_PRODUCTION_VARIANT], @@ -463,7 +473,12 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model, serverless_inference_config=serverless_inference_config, ) - create_sagemaker_model.assert_called_with(None, None, None, serverless_inference_config) + create_sagemaker_model.assert_called_with( + instance_type=None, + accelerator_type=None, + tags=None, + serverless_inference_config=serverless_inference_config, + ) production_variant.assert_called_with( MODEL_NAME, None,