from sagemaker.inputs import CompilationInput
from sagemaker.deprecations import removed_kwargs
from sagemaker.predictor import PredictorBase
+ from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.transformer import Transformer

LOGGER = logging.getLogger("sagemaker")
@@ -209,7 +210,7 @@ def register(
            model_package_arn=model_package.get("ModelPackageArn"),
        )

-     def _init_sagemaker_session_if_does_not_exist(self, instance_type):
+     def _init_sagemaker_session_if_does_not_exist(self, instance_type=None):
        """Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.

        The type of session object is determined by the instance type.
@@ -688,8 +689,8 @@ def compile(

    def deploy(
        self,
-         initial_instance_count,
-         instance_type,
+         initial_instance_count=None,
+         instance_type=None,
        serializer=None,
        deserializer=None,
        accelerator_type=None,
@@ -698,6 +699,7 @@ def deploy(
        kms_key=None,
        wait=True,
        data_capture_config=None,
+         serverless_inference_config=None,
        **kwargs,
    ):
        """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -715,9 +717,13 @@ def deploy(

        Args:
            initial_instance_count (int): The initial number of instances to run
-                 in the ``Endpoint`` created from this ``Model``.
+                 in the ``Endpoint`` created from this ``Model``. If not using
+                 serverless inference, it needs to be a number larger than or
+                 equal to 1 (default: None).
            instance_type (str): The EC2 instance type to deploy this Model to.
-                 For example, 'ml.p2.xlarge', or 'local' for local mode.
+                 For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
+                 serverless inference, it is required to deploy a model.
+                 (default: None)
            serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
                serializer object, used to encode data for an inference endpoint
                (default: None). If ``serializer`` is not None, then
@@ -746,7 +752,17 @@ def deploy(
            data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
                configuration related to Endpoint data capture for use with
                Amazon SageMaker Model Monitoring. Default: None.
-
+             serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
+                 Specifies configuration related to serverless endpoints. Use this
+                 configuration when creating a serverless endpoint and making serverless
+                 inference. If an empty object is passed through, the pre-defined values in
+                 the ``ServerlessInferenceConfig`` class are used to deploy the serverless
+                 endpoint (default: None).
+         Raises:
+             ValueError: If the argument combination check fails in any of these circumstances:
+                 - If no role is specified or
+                 - If serverless inference config is not specified and instance type and
+                   instance count are also not specified or
+                 - If a wrong type of object is provided as serverless inference config
        Returns:
            callable[string, sagemaker.session.Session] or None: Invocation of
                ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
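A minimal usage sketch for the parameter documented above: the image URI, model data location, and execution role are placeholders, and the ``ServerlessInferenceConfig`` keyword arguments (memory_size_in_mb, max_concurrency) are assumed from the released SDK rather than taken from this diff.

from sagemaker.model import Model
from sagemaker.serverless import ServerlessInferenceConfig

# Hypothetical model definition; replace the image, artifact, and role with real values.
model = Model(
    image_uri="<inference-image-uri>",
    model_data="s3://<bucket>/<prefix>/model.tar.gz",
    role="<execution-role-arn>",
)

# With a serverless config, instance_type and initial_instance_count can be omitted.
# memory_size_in_mb and max_concurrency are assumed constructor arguments.
predictor = model.deploy(
    serverless_inference_config=ServerlessInferenceConfig(
        memory_size_in_mb=2048,
        max_concurrency=5,
    )
)
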
@@ -758,27 +774,47 @@ def deploy(
        if self.role is None:
            raise ValueError("Role can not be null for deploying a model")

-         if instance_type.startswith("ml.inf") and not self._is_compiled_model:
+         is_serverless = serverless_inference_config is not None
+         if not is_serverless and not (instance_type and initial_instance_count):
+             raise ValueError(
+                 "Must specify instance type and instance count unless using serverless inference"
+             )
+
+         if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig):
+             raise ValueError(
+                 "serverless_inference_config needs to be a ServerlessInferenceConfig object"
+             )
+
+         if instance_type and instance_type.startswith("ml.inf") and not self._is_compiled_model:
            LOGGER.warning(
                "Your model is not compiled. Please compile your model before using Inferentia."
            )

-         compiled_model_suffix = "-".join(instance_type.split(".")[:-1])
-         if self._is_compiled_model:
+         compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
+         if self._is_compiled_model and not is_serverless:
            self._ensure_base_name_if_needed(self.image_uri)
            if self._base_name is not None:
                self._base_name = "-".join((self._base_name, compiled_model_suffix))

        self._create_sagemaker_model(instance_type, accelerator_type, tags)
+
+         serverless_inference_config_dict = (
+             serverless_inference_config._to_request_dict() if is_serverless else None
+         )
        production_variant = sagemaker.production_variant(
-             self.name, instance_type, initial_instance_count, accelerator_type=accelerator_type
+             self.name,
+             instance_type,
+             initial_instance_count,
+             accelerator_type=accelerator_type,
+             serverless_inference_config=serverless_inference_config_dict,
        )
        if endpoint_name:
            self.endpoint_name = endpoint_name
        else:
            base_endpoint_name = self._base_name or utils.base_from_name(self.name)
-             if self._is_compiled_model and not base_endpoint_name.endswith(compiled_model_suffix):
-                 base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix))
+             if self._is_compiled_model and not is_serverless:
+                 if not base_endpoint_name.endswith(compiled_model_suffix):
+                     base_endpoint_name = "-".join((base_endpoint_name, compiled_model_suffix))
            self.endpoint_name = utils.name_from_base(base_endpoint_name)

        data_capture_config_dict = None
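
A hedged sketch of how the new argument validation above behaves, reusing the hypothetical model instance from the earlier example; the error messages are the ones introduced in this diff.

# Neither a serverless config nor instance settings: the new check raises.
try:
    model.deploy()
except ValueError as err:
    print(err)  # Must specify instance type and instance count unless using serverless inference

# A plain dict instead of a ServerlessInferenceConfig object also fails the type check.
try:
    model.deploy(serverless_inference_config={"memory_size_in_mb": 2048})
except ValueError as err:
    print(err)  # serverless_inference_config needs to be a ServerlessInferenceConfig object

# Instance-based deployment is unchanged and still requires both values.
predictor = model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")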