Skip to content

Commit c981604

Browse files
committed
feature: add support for serverless inference
1 parent 127c964 commit c981604

12 files changed

+373
-28
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def read_version():
3434
# Declare minimal set for installation
3535
required_packages = [
3636
"attrs",
37-
"boto3>=1.20.18",
37+
"boto3>=1.20.21",
3838
"google-pasta",
3939
"numpy>=1.9.0",
4040
"protobuf>=3.1",

src/sagemaker/estimator.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -852,8 +852,8 @@ def logs(self):
852852

853853
def deploy(
854854
self,
855-
initial_instance_count,
856-
instance_type,
855+
initial_instance_count=None,
856+
instance_type=None,
857857
serializer=None,
858858
deserializer=None,
859859
accelerator_type=None,
@@ -864,6 +864,7 @@ def deploy(
864864
kms_key=None,
865865
data_capture_config=None,
866866
tags=None,
867+
serverless_inference_config=None,
867868
**kwargs,
868869
):
869870
"""Deploy the trained model to an Amazon SageMaker endpoint.
@@ -874,10 +875,14 @@ def deploy(
874875
http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html
875876
876877
Args:
877-
initial_instance_count (int): Minimum number of EC2 instances to
878-
deploy to an endpoint for prediction.
879-
instance_type (str): Type of EC2 instance to deploy to an endpoint
880-
for prediction, for example, 'ml.c4.xlarge'.
878+
initial_instance_count (int): The initial number of instances to run
879+
in the ``Endpoint`` created from this ``Model``. If not using
880+
serverless inference, then it needs to be a number greater than or equal
881+
to 1 (default: None)
882+
instance_type (str): The EC2 instance type to deploy this Model to.
883+
For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
884+
serverless inference, then it is required to deploy a model.
885+
(default: None)
881886
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
882887
serializer object, used to encode data for an inference endpoint
883888
(default: None). If ``serializer`` is not None, then
@@ -910,6 +915,11 @@ def deploy(
910915
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
911916
configuration related to Endpoint data capture for use with
912917
Amazon SageMaker Model Monitoring. Default: None.
918+
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
919+
Specifies configuration related to serverless endpoint. Use this configuration
920+
when trying to create serverless endpoint and make serverless inference. If
921+
empty object passed through, we will use pre-defined values in
922+
``ServerlessInferenceConfig`` class to deploy serverless endpoint (default: None)
913923
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific
914924
endpoint. Example:
915925
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
@@ -927,14 +937,15 @@ def deploy(
927937
endpoint and obtain inferences.
928938
"""
929939
removed_kwargs("update_endpoint", kwargs)
940+
is_serverless = serverless_inference_config is not None
930941
self._ensure_latest_training_job()
931942
self._ensure_base_job_name()
932943
default_name = name_from_base(self.base_job_name)
933944
endpoint_name = endpoint_name or default_name
934945
model_name = model_name or default_name
935946

936947
self.deploy_instance_type = instance_type
937-
if use_compiled_model:
948+
if use_compiled_model and not is_serverless:
938949
family = "_".join(instance_type.split(".")[:-1])
939950
if family not in self._compiled_models:
940951
raise ValueError(
@@ -959,6 +970,7 @@ def deploy(
959970
wait=wait,
960971
kms_key=kms_key,
961972
data_capture_config=data_capture_config,
973+
serverless_inference_config=serverless_inference_config,
962974
)
963975

964976
def register(

src/sagemaker/model.py

+48-12
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sagemaker.inputs import CompilationInput
3333
from sagemaker.deprecations import removed_kwargs
3434
from sagemaker.predictor import PredictorBase
35+
from sagemaker.serverless import ServerlessInferenceConfig
3536
from sagemaker.transformer import Transformer
3637

3738
LOGGER = logging.getLogger("sagemaker")
@@ -209,7 +210,7 @@ def register(
209210
model_package_arn=model_package.get("ModelPackageArn"),
210211
)
211212

212-
def _init_sagemaker_session_if_does_not_exist(self, instance_type):
213+
def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
213214
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
214215
215216
The type of session object is determined by the instance type.
@@ -688,8 +689,8 @@ def compile(
688689

689690
def deploy(
690691
self,
691-
initial_instance_count,
692-
instance_type,
692+
initial_instance_count=None,
693+
instance_type=None,
693694
serializer=None,
694695
deserializer=None,
695696
accelerator_type=None,
@@ -698,6 +699,7 @@ def deploy(
698699
kms_key=None,
699700
wait=True,
700701
data_capture_config=None,
702+
serverless_inference_config=None,
701703
**kwargs,
702704
):
703705
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -715,9 +717,13 @@ def deploy(
715717
716718
Args:
717719
initial_instance_count (int): The initial number of instances to run
718-
in the ``Endpoint`` created from this ``Model``.
720+
in the ``Endpoint`` created from this ``Model``. If not using
721+
serverless inference, then it needs to be a number greater than or equal
722+
to 1 (default: None)
719723
instance_type (str): The EC2 instance type to deploy this Model to.
720-
For example, 'ml.p2.xlarge', or 'local' for local mode.
724+
For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
725+
serverless inference, then it is required to deploy a model.
726+
(default: None)
721727
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
722728
serializer object, used to encode data for an inference endpoint
723729
(default: None). If ``serializer`` is not None, then
@@ -746,7 +752,17 @@ def deploy(
746752
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
747753
configuration related to Endpoint data capture for use with
748754
Amazon SageMaker Model Monitoring. Default: None.
749-
755+
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
756+
Specifies configuration related to serverless endpoint. Use this configuration
757+
when trying to create serverless endpoint and make serverless inference. If
758+
empty object passed through, we will use pre-defined values in
759+
``ServerlessInferenceConfig`` class to deploy serverless endpoint (default: None)
760+
Raises:
761+
ValueError: If arguments combination check failed in these circumstances:
762+
- If no role is specified or
763+
- If serverless inference config is not specified and instance type and instance
764+
count are also not specified or
765+
- If a wrong type of object is provided as serverless inference config
750766
Returns:
751767
callable[string, sagemaker.session.Session] or None: Invocation of
752768
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
@@ -758,27 +774,47 @@ def deploy(
758774
if self.role is None:
759775
raise ValueError("Role can not be null for deploying a model")
760776

761-
if instance_type.startswith("ml.inf") and not self._is_compiled_model:
777+
is_serverless = serverless_inference_config is not None
778+
if not is_serverless and not (instance_type and initial_instance_count):
779+
raise ValueError(
780+
"Must specify instance type and instance count unless using serverless inference"
781+
)
782+
783+
if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig):
784+
raise ValueError(
785+
"serverless_inference_config needs to be a ServerlessInferenceConfig object"
786+
)
787+
788+
if instance_type and instance_type.startswith("ml.inf") and not self._is_compiled_model:
762789
LOGGER.warning(
763790
"Your model is not compiled. Please compile your model before using Inferentia."
764791
)
765792

766-
compiled_model_suffix = "-".join(instance_type.split(".")[:-1])
767-
if self._is_compiled_model:
793+
compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
794+
if self._is_compiled_model and not is_serverless:
768795
self._ensure_base_name_if_needed(self.image_uri)
769796
if self._base_name is not None:
770797
self._base_name = "-".join((self._base_name, compiled_model_suffix))
771798

772799
self._create_sagemaker_model(instance_type, accelerator_type, tags)
800+
801+
serverless_inference_config_dict = (
802+
serverless_inference_config._to_request_dict() if is_serverless else None
803+
)
773804
production_variant = sagemaker.production_variant(
774-
self.name, instance_type, initial_instance_count, accelerator_type=accelerator_type
805+
self.name,
806+
instance_type,
807+
initial_instance_count,
808+
accelerator_type=accelerator_type,
809+
serverless_inference_config=serverless_inference_config_dict,
775810
)
776811
if endpoint_name:
777812
self.endpoint_name = endpoint_name
778813
else:
779814
base_endpoint_name = self._base_name or utils.base_from_name(self.name)
780-
if self._is_compiled_model and not base_endpoint_name.endswith(compiled_model_suffix):
781-
base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix))
815+
if self._is_compiled_model and not is_serverless:
816+
if not base_endpoint_name.endswith(compiled_model_suffix):
817+
base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix))
782818
self.endpoint_name = utils.name_from_base(base_endpoint_name)
783819

784820
data_capture_config_dict = None

src/sagemaker/serverless/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@
1313
"""Classes for performing machine learning on serverless compute."""
1414
from sagemaker.serverless.model import LambdaModel # noqa: F401
1515
from sagemaker.serverless.predictor import LambdaPredictor # noqa: F401
16+
from sagemaker.serverless.serverless_inference_config import ( # noqa: F401
17+
ServerlessInferenceConfig,
18+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code related to the ServerlessInferenceConfig class.
14+
15+
This code is used for configuring a serverless inference endpoint. Use it when deploying
16+
the model to the endpoints.
17+
"""
18+
from __future__ import print_function, absolute_import
19+
20+
21+
class ServerlessInferenceConfig(object):
    """Configuration for deploying a model to a serverless SageMaker endpoint.

    Holds the serverless-specific settings (memory size and maximum
    concurrency) that are translated into the ``ServerlessConfig`` portion
    of a production-variant request when the model is deployed.
    """

    def __init__(self, memory_size_in_mb=2048, max_concurrency=5):
        """Initialize serverless inference settings.

        Args:
            memory_size_in_mb (int): Optional. Memory allocated to the
                serverless endpoint, in MB. Valid values come in 1 GB
                increments: 1024 MB, 2048 MB, 3072 MB, 4096 MB, 5120 MB,
                or 6144 MB. If no value is provided, Amazon SageMaker will
                choose the default value for you. (Default: 2048)
            max_concurrency (int): Optional. Upper bound on concurrent
                invocations the serverless endpoint may process. If no
                value is provided, Amazon SageMaker will choose the
                default value for you. (Default: 5)
        """
        self.memory_size_in_mb = memory_size_in_mb
        self.max_concurrency = max_concurrency

    def _to_request_dict(self):
        """Return these settings as a ``ServerlessConfig`` request dict."""
        return {
            "MemorySizeInMB": self.memory_size_in_mb,
            "MaxConcurrency": self.max_concurrency,
        }

src/sagemaker/session.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -4382,11 +4382,12 @@ def pipeline_container_def(models, instance_type=None):
43824382

43834383
def production_variant(
43844384
model_name,
4385-
instance_type,
4386-
initial_instance_count=1,
4385+
instance_type=None,
4386+
initial_instance_count=None,
43874387
variant_name="AllTraffic",
43884388
initial_weight=1,
43894389
accelerator_type=None,
4390+
serverless_inference_config=None,
43904391
):
43914392
"""Create a production variant description suitable for use in a ``ProductionVariant`` list.
43924393
@@ -4405,21 +4406,29 @@ def production_variant(
44054406
accelerator_type (str): Type of Elastic Inference accelerator for this production variant.
44064407
For example, 'ml.eia1.medium'.
44074408
For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
4409+
serverless_inference_config (dict): Specifies configuration dict related to serverless
4410+
endpoint. The dict is converted from a sagemaker.serverless.ServerlessInferenceConfig
4411+
object (default: None)
44084412
44094413
Returns:
44104414
dict[str, str]: An SageMaker ``ProductionVariant`` description
44114415
"""
44124416
production_variant_configuration = {
44134417
"ModelName": model_name,
4414-
"InstanceType": instance_type,
4415-
"InitialInstanceCount": initial_instance_count,
44164418
"VariantName": variant_name,
44174419
"InitialVariantWeight": initial_weight,
44184420
}
44194421

44204422
if accelerator_type:
44214423
production_variant_configuration["AcceleratorType"] = accelerator_type
44224424

4425+
if serverless_inference_config:
4426+
production_variant_configuration["ServerlessConfig"] = serverless_inference_config
4427+
else:
4428+
initial_instance_count = initial_instance_count or 1
4429+
production_variant_configuration["InitialInstanceCount"] = initial_instance_count
4430+
production_variant_configuration["InstanceType"] = instance_type
4431+
44234432
return production_variant_configuration
44244433

44254434

src/sagemaker/tensorflow/model.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,8 @@ def register(
258258

259259
def deploy(
260260
self,
261-
initial_instance_count,
262-
instance_type,
261+
initial_instance_count=None,
262+
instance_type=None,
263263
serializer=None,
264264
deserializer=None,
265265
accelerator_type=None,
@@ -269,6 +269,7 @@ def deploy(
269269
wait=True,
270270
data_capture_config=None,
271271
update_endpoint=None,
272+
serverless_inference_config=None,
272273
):
273274
"""Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""
274275

@@ -287,6 +288,7 @@ def deploy(
287288
kms_key=kms_key,
288289
wait=wait,
289290
data_capture_config=data_capture_config,
291+
serverless_inference_config=serverless_inference_config,
290292
update_endpoint=update_endpoint,
291293
)
292294

0 commit comments

Comments
 (0)