diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py
index eddd8de698..8fe5a9a939 100644
--- a/src/sagemaker/automl/automl.py
+++ b/src/sagemaker/automl/automl.py
@@ -350,6 +350,9 @@ def deploy(
         model_kms_key=None,
         predictor_cls=None,
         inference_response_keys=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
     ):
         """Deploy a candidate to a SageMaker Inference Pipeline.
 
@@ -396,6 +399,16 @@ def deploy(
                 function on the created endpoint name.
             inference_response_keys (list): List of keys for response content. The order of the
                 keys will dictate the content order in the response.
+            volume_size (int): The size, in GB, of the ML storage volume attached to individual
+                inference instance associated with the production variant. Currently only Amazon EBS
+                gp2 storage volumes are supported.
+            model_data_download_timeout (int): The timeout value, in seconds, to download and
+                extract model data from Amazon S3 to the individual inference instance associated
+                with this production variant.
+            container_startup_health_check_timeout (int): The timeout value, in seconds, for your
+                inference container to pass a health check by SageMaker Hosting. For more information
+                about health checks, see:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
 
         Returns:
             callable[string, sagemaker.session.Session] or ``None``:
@@ -423,6 +436,9 @@ def deploy(
             kms_key=model_kms_key,
             tags=tags,
             wait=wait,
+            volume_size=volume_size,
+            model_data_download_timeout=model_data_download_timeout,
+            container_startup_health_check_timeout=container_startup_health_check_timeout,
         )
 
     def _check_problem_type_and_job_objective(self, problem_type, job_objective):
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index bcffd70013..6f729267de 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -1304,6 +1304,9 @@ def deploy(
         tags=None,
         serverless_inference_config=None,
         async_inference_config=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
         **kwargs,
     ):
         """Deploy the trained model to an Amazon SageMaker endpoint.
@@ -1371,6 +1374,16 @@ def deploy(
                 For more information about tags, see
                 https://boto3.amazonaws.com/v1/documentation\
                 /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
+            volume_size (int): The size, in GB, of the ML storage volume attached to individual
+                inference instance associated with the production variant. Currently only Amazon EBS
+                gp2 storage volumes are supported.
+            model_data_download_timeout (int): The timeout value, in seconds, to download and
+                extract model data from Amazon S3 to the individual inference instance associated
+                with this production variant.
+            container_startup_health_check_timeout (int): The timeout value, in seconds, for your
+                inference container to pass a health check by SageMaker Hosting. For more information
+                about health checks, see:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
             **kwargs: Passed to invocation of ``create_model()``.
                 Implementations may customize ``create_model()`` to accept
                 ``**kwargs`` to customize model creation during deploy.
@@ -1429,6 +1442,9 @@ def deploy(
             data_capture_config=data_capture_config,
             serverless_inference_config=serverless_inference_config,
             async_inference_config=async_inference_config,
+            volume_size=volume_size,
+            model_data_download_timeout=model_data_download_timeout,
+            container_startup_health_check_timeout=container_startup_health_check_timeout,
         )
 
     def register(
diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py
index a66c7e2389..6b93470c3a 100644
--- a/src/sagemaker/huggingface/model.py
+++ b/src/sagemaker/huggingface/model.py
@@ -206,6 +206,9 @@ def deploy(
         data_capture_config=None,
         async_inference_config=None,
         serverless_inference_config=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -269,6 +272,16 @@ def deploy(
                 empty object passed through, will use pre-defined values in
                 ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
                 instance based endpoint if it's None. (default: None)
+            volume_size (int): The size, in GB, of the ML storage volume attached to individual
+                inference instance associated with the production variant. Currently only Amazon EBS
+                gp2 storage volumes are supported.
+            model_data_download_timeout (int): The timeout value, in seconds, to download and
+                extract model data from Amazon S3 to the individual inference instance associated
+                with this production variant.
+            container_startup_health_check_timeout (int): The timeout value, in seconds, for your
+                inference container to pass a health check by SageMaker Hosting. For more information
+                about health checks, see:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If no role is specified or
@@ -301,6 +314,9 @@ def deploy(
             data_capture_config,
             async_inference_config,
             serverless_inference_config,
+            volume_size=volume_size,
+            model_data_download_timeout=model_data_download_timeout,
+            container_startup_health_check_timeout=container_startup_health_check_timeout,
         )
 
     def register(
diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py
index 821e9ae037..4fc0552d64 100644
--- a/src/sagemaker/model.py
+++ b/src/sagemaker/model.py
@@ -1029,6 +1029,9 @@ def deploy(
         data_capture_config=None,
         async_inference_config=None,
         serverless_inference_config=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -1092,6 +1095,16 @@ def deploy(
                 empty object passed through, will use pre-defined values in
                 ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
                 instance based endpoint if it's None. (default: None)
+            volume_size (int): The size, in GB, of the ML storage volume attached to individual
+                inference instance associated with the production variant. Currently only Amazon EBS
+                gp2 storage volumes are supported.
+            model_data_download_timeout (int): The timeout value, in seconds, to download and
+                extract model data from Amazon S3 to the individual inference instance associated
+                with this production variant.
+            container_startup_health_check_timeout (int): The timeout value, in seconds, for your
+                inference container to pass a health check by SageMaker Hosting. For more information
+                about health checks, see:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If no role is specified or
@@ -1155,6 +1168,9 @@ def deploy(
             initial_instance_count,
             accelerator_type=accelerator_type,
             serverless_inference_config=serverless_inference_config_dict,
+            volume_size=volume_size,
+            model_data_download_timeout=model_data_download_timeout,
+            container_startup_health_check_timeout=container_startup_health_check_timeout,
         )
         if endpoint_name:
             self.endpoint_name = endpoint_name
diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py
index 5a293d0aec..ad5ed1291c 100644
--- a/src/sagemaker/pipeline.py
+++ b/src/sagemaker/pipeline.py
@@ -122,6 +122,9 @@ def deploy(
         update_endpoint=False,
         data_capture_config=None,
         kms_key=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
     ):
         """Deploy the ``Model`` to an ``Endpoint``.
 
@@ -170,6 +173,16 @@ def deploy(
             kms_key (str): The ARN, Key ID or Alias of the KMS key that is used to
                 encrypt the data on the storage volume attached to the instance hosting
                 the endpoint.
+            volume_size (int): The size, in GB, of the ML storage volume attached to individual
+                inference instance associated with the production variant. Currently only Amazon EBS
+                gp2 storage volumes are supported.
+            model_data_download_timeout (int): The timeout value, in seconds, to download and
+                extract model data from Amazon S3 to the individual inference instance associated
+                with this production variant.
+            container_startup_health_check_timeout (int): The timeout value, in seconds, for your
+                inference container to pass a health check by SageMaker Hosting. For more information
+                about health checks, see:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
 
         Returns:
             callable[string, sagemaker.session.Session] or None: Invocation of
@@ -191,7 +204,12 @@ def deploy(
         )
 
         production_variant = sagemaker.production_variant(
-            self.name, instance_type, initial_instance_count
+            self.name,
+            instance_type,
+            initial_instance_count,
+            volume_size=volume_size,
+            model_data_download_timeout=model_data_download_timeout,
+            container_startup_health_check_timeout=container_startup_health_check_timeout,
         )
         self.endpoint_name = endpoint_name or self.name
 
@@ -208,6 +226,9 @@ def deploy(
                 tags=tags,
                 kms_key=kms_key,
                 data_capture_config_dict=data_capture_config_dict,
+                volume_size=volume_size,
+                model_data_download_timeout=model_data_download_timeout,
+                container_startup_health_check_timeout=container_startup_health_check_timeout,
             )
             self.sagemaker_session.update_endpoint(
                 self.endpoint_name, endpoint_config_name, wait=wait
diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py
index 733263ce0b..36b8d76e7b 100644
--- a/src/sagemaker/session.py
+++ b/src/sagemaker/session.py
@@ -2981,6 +2981,9 @@ def create_endpoint_config(
         tags=None,
         kms_key=None,
         data_capture_config_dict=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
     ):
         """Create an Amazon SageMaker endpoint configuration.
 
@@ -3004,6 +3007,16 @@ def create_endpoint_config(
                 attached to the instance hosting the endpoint.
             data_capture_config_dict (dict): Specifies configuration related to Endpoint data
                 capture for use with Amazon SageMaker Model Monitoring. Default: None.
+            volume_size (int): The size, in GB, of the ML storage volume attached to individual
+                inference instance associated with the production variant. Currently only Amazon EBS
+                gp2 storage volumes are supported.
+            model_data_download_timeout (int): The timeout value, in seconds, to download and
+                extract model data from Amazon S3 to the individual inference instance associated
+                with this production variant.
+            container_startup_health_check_timeout (int): The timeout value, in seconds, for your
+                inference container to pass a health check by SageMaker Hosting. For more information
+                about health checks, see:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
 
         Example:
             >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
@@ -3025,6 +3038,9 @@ def create_endpoint_config(
                     instance_type,
                     initial_instance_count,
                     accelerator_type=accelerator_type,
+                    volume_size=volume_size,
+                    model_data_download_timeout=model_data_download_timeout,
+                    container_startup_health_check_timeout=container_startup_health_check_timeout,
                 )
             ],
         }
@@ -4636,6 +4652,9 @@ def production_variant(
     initial_weight=1,
     accelerator_type=None,
     serverless_inference_config=None,
+    volume_size=None,
+    model_data_download_timeout=None,
+    container_startup_health_check_timeout=None,
 ):
     """Create a production variant description suitable for use in a ``ProductionVariant`` list.
 
@@ -4657,7 +4676,16 @@ def production_variant(
         serverless_inference_config (dict): Specifies configuration dict related to
             serverless endpoint. The dict is converted from
             sagemaker.model_monitor.ServerlessInferenceConfig object (default: None)
-
+        volume_size (int): The size, in GB, of the ML storage volume attached to individual
+            inference instance associated with the production variant. Currently only Amazon EBS
+            gp2 storage volumes are supported.
+        model_data_download_timeout (int): The timeout value, in seconds, to download and extract
+            model data from Amazon S3 to the individual inference instance associated with this
+            production variant.
+        container_startup_health_check_timeout (int): The timeout value, in seconds, for your
+            inference container to pass a health check by SageMaker Hosting. For more information
+            about health checks, see:
+            https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
     Returns:
         dict[str, str]: An SageMaker ``ProductionVariant`` description
     """
@@ -4676,6 +4704,12 @@ def production_variant(
         initial_instance_count = initial_instance_count or 1
         production_variant_configuration["InitialInstanceCount"] = initial_instance_count
         production_variant_configuration["InstanceType"] = instance_type
+        update_args(
+            production_variant_configuration,
+            VolumeSizeInGB=volume_size,
+            ModelDataDownloadTimeoutInSeconds=model_data_download_timeout,
+            ContainerStartupHealthCheckTimeoutInSeconds=container_startup_health_check_timeout,
+        )
 
     return production_variant_configuration
 
diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py
index 05109102f2..4a93e4e99e 100644
--- a/src/sagemaker/tensorflow/model.py
+++ b/src/sagemaker/tensorflow/model.py
@@ -320,6 +320,9 @@ def deploy(
         update_endpoint=None,
         async_inference_config=None,
         serverless_inference_config=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
     ):
         """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""
 
@@ -340,6 +343,9 @@ def deploy(
             data_capture_config=data_capture_config,
             async_inference_config=async_inference_config,
             serverless_inference_config=serverless_inference_config,
+            volume_size=volume_size,
+            model_data_download_timeout=model_data_download_timeout,
+            container_startup_health_check_timeout=container_startup_health_check_timeout,
             update_endpoint=update_endpoint,
         )
 
diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py
index 1b25fef693..e3618da87a 100644
--- a/tests/unit/sagemaker/automl/test_auto_ml.py
+++ b/tests/unit/sagemaker/automl/test_auto_ml.py
@@ -596,6 +596,9 @@ def test_deploy_optional_args(candidate_estimator, sagemaker_session, candidate_
         deserializer=None,
         endpoint_name=JOB_NAME,
         kms_key=OUTPUT_KMS_KEY,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
         tags=TAGS,
         wait=False,
     )
diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py
index af8e1a9e0e..7d60806726 100644
--- a/tests/unit/sagemaker/model/test_deploy.py
+++ b/tests/unit/sagemaker/model/test_deploy.py
@@ -71,6 +71,9 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem
         INSTANCE_COUNT,
         accelerator_type=None,
         serverless_inference_config=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
     )
 
     sagemaker_session.create_model.assert_called_with(
@@ -120,6 +123,9 @@ def test_deploy_accelerator_type(
         INSTANCE_COUNT,
         accelerator_type=ACCELERATOR_TYPE,
         serverless_inference_config=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
     )
 
     sagemaker_session.endpoint_from_production_variants.assert_called_with(
@@ -363,6 +369,9 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model,
         None,
         accelerator_type=None,
         serverless_inference_config=serverless_inference_config_dict,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
     )
 
     sagemaker_session.endpoint_from_production_variants.assert_called_with(
@@ -474,3 +483,72 @@ def test_deploy_predictor_cls(production_variant, sagemaker_session):
     assert predictor_async.name == model.name
     assert predictor_async.endpoint_name == endpoint_name_async
     assert predictor_async.sagemaker_session == sagemaker_session
+
+
+@patch("sagemaker.production_variant")
+@patch("sagemaker.model.Model.prepare_container_def")
+@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME)
+def test_deploy_customized_volume_size_and_timeout(
+    name_from_base, prepare_container_def, production_variant, sagemaker_session
+):
+    volume_size_gb = 256
+    model_data_download_timeout_sec = 1800
+    startup_health_check_timeout_sec = 1800
+
+    production_variant_result = copy.deepcopy(BASE_PRODUCTION_VARIANT)
+    production_variant_result.update(
+        {
+            "VolumeSizeInGB": volume_size_gb,
+            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_sec,
+            "ContainerStartupHealthCheckTimeoutInSeconds": startup_health_check_timeout_sec,
+        }
+    )
+    production_variant.return_value = production_variant_result
+
+    container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
+    prepare_container_def.return_value = container_def
+
+    model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
+    model.deploy(
+        instance_type=INSTANCE_TYPE,
+        initial_instance_count=INSTANCE_COUNT,
+        volume_size=volume_size_gb,
+        model_data_download_timeout=model_data_download_timeout_sec,
+        container_startup_health_check_timeout=startup_health_check_timeout_sec,
+    )
+
+    name_from_base.assert_called_with(MODEL_IMAGE)
+    assert 2 == name_from_base.call_count
+
+    prepare_container_def.assert_called_with(
+        INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None
+    )
+    production_variant.assert_called_with(
+        MODEL_NAME,
+        INSTANCE_TYPE,
+        INSTANCE_COUNT,
+        accelerator_type=None,
+        serverless_inference_config=None,
+        volume_size=volume_size_gb,
+        model_data_download_timeout=model_data_download_timeout_sec,
+        container_startup_health_check_timeout=startup_health_check_timeout_sec,
+    )
+
+    sagemaker_session.create_model.assert_called_with(
+        name=MODEL_NAME,
+        role=ROLE,
+        container_defs=container_def,
+        vpc_config=None,
+        enable_network_isolation=False,
+        tags=None,
+    )
+
+    sagemaker_session.endpoint_from_production_variants.assert_called_with(
+        name=MODEL_NAME,
+        production_variants=[production_variant_result],
+        tags=None,
+        kms_key=None,
+        wait=True,
+        data_capture_config_dict=None,
+        async_inference_config_dict=None,
+    )
diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py
index cd63fb0678..34e6a43fcf 100644
--- a/tests/unit/test_estimator.py
+++ b/tests/unit/test_estimator.py
@@ -3173,6 +3173,9 @@ def test_generic_to_deploy_kms(create_model, sagemaker_session):
         data_capture_config=None,
         async_inference_config=None,
         serverless_inference_config=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
     )
 
 
@@ -3313,6 +3316,54 @@ def test_deploy_with_no_model_name(sagemaker_session):
     assert kwargs["name"].startswith(IMAGE_URI)
 
 
+@patch("sagemaker.estimator.Estimator.create_model")
+def test_deploy_with_customized_volume_size_timeout(create_model, sagemaker_session):
+    estimator = Estimator(
+        IMAGE_URI,
+        ROLE,
+        INSTANCE_COUNT,
+        INSTANCE_TYPE,
+        output_path=OUTPUT_PATH,
+        sagemaker_session=sagemaker_session,
+    )
+    estimator.set_hyperparameters(**HYPERPARAMS)
+    estimator.fit({"train": "s3://bucket/training-prefix"})
+    endpoint_name = "endpoint-name"
+    volume_size_gb = 256
+    model_data_download_timeout_sec = 600
+    startup_health_check_timeout_sec = 600
+
+    model = MagicMock()
+    create_model.return_value = model
+
+    estimator.deploy(
+        INSTANCE_COUNT,
+        INSTANCE_TYPE,
+        endpoint_name=endpoint_name,
+        volume_size=volume_size_gb,
+        model_data_download_timeout=model_data_download_timeout_sec,
+        container_startup_health_check_timeout=startup_health_check_timeout_sec,
+    )
+
+    model.deploy.assert_called_with(
+        instance_type=INSTANCE_TYPE,
+        initial_instance_count=INSTANCE_COUNT,
+        serializer=None,
+        deserializer=None,
+        accelerator_type=None,
+        endpoint_name=endpoint_name,
+        tags=None,
+        wait=True,
+        kms_key=None,
+        data_capture_config=None,
+        async_inference_config=None,
+        serverless_inference_config=None,
+        volume_size=volume_size_gb,
+        model_data_download_timeout=model_data_download_timeout_sec,
+        container_startup_health_check_timeout=startup_health_check_timeout_sec,
+    )
+
+
 def test_register_default_image(sagemaker_session):
     estimator = Estimator(
         IMAGE_URI,
diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py
index 9dbbbb4ad6..f4fb892d21 100644
--- a/tests/unit/test_pipeline_model.py
+++ b/tests/unit/test_pipeline_model.py
@@ -188,6 +188,9 @@ def test_deploy_update_endpoint(tfo, time, sagemaker_session):
         tags=None,
         kms_key=None,
         data_capture_config_dict=None,
+        volume_size=None,
+        model_data_download_timeout=None,
+        container_startup_health_check_timeout=None,
     )
     config_name = sagemaker_session.create_endpoint_config(
         name=model.name,
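Usage sketch (illustrative only, not part of the patch): the three new keyword arguments flow from ``deploy()`` through ``sagemaker.production_variant()``, which maps them onto the ``VolumeSizeInGB``, ``ModelDataDownloadTimeoutInSeconds``, and ``ContainerStartupHealthCheckTimeoutInSeconds`` fields of the production variant. The image URI, model artifact, role, and instance type below are placeholders, not values taken from this change.

    from sagemaker.model import Model

    # Placeholder image, artifact, and role -- substitute real resources before running.
    model = Model(
        image_uri="<account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>",
        model_data="s3://<bucket>/<prefix>/model.tar.gz",
        role="arn:aws:iam::<account>:role/<sagemaker-execution-role>",
    )

    # Attach a 256 GB EBS gp2 volume to each instance and allow 30 minutes each for
    # model-data download and for the container to pass its startup health checks.
    model.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.xlarge",
        volume_size=256,
        model_data_download_timeout=1800,
        container_startup_health_check_timeout=1800,
    )

Because ``update_args`` (as used elsewhere in ``session.py``) only copies non-``None`` values, omitting the new arguments leaves the request unchanged; and since the call sits on the instance-based branch of ``production_variant()`` next to ``InitialInstanceCount`` and ``InstanceType``, the settings do not apply to serverless endpoints.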