@@ -209,7 +209,7 @@ def register(
209
209
model_package_arn = model_package .get ("ModelPackageArn" ),
210
210
)
211
211
212
- def _init_sagemaker_session_if_does_not_exist (self , instance_type ):
212
+ def _init_sagemaker_session_if_does_not_exist (self , instance_type = None ):
213
213
"""Set ``self.sagemaker_session`` to ``LocalSession`` or ``Session`` if it's not already.
214
214
215
215
The type of session object is determined by the instance type.
@@ -688,8 +688,8 @@ def compile(
688
688
689
689
def deploy (
690
690
self ,
691
- initial_instance_count ,
692
- instance_type ,
691
+ initial_instance_count = None ,
692
+ instance_type = None ,
693
693
serializer = None ,
694
694
deserializer = None ,
695
695
accelerator_type = None ,
@@ -698,6 +698,7 @@ def deploy(
698
698
kms_key = None ,
699
699
wait = True ,
700
700
data_capture_config = None ,
701
+ serverless_inference_config = None ,
701
702
** kwargs ,
702
703
):
703
704
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -715,9 +716,13 @@ def deploy(
715
716
716
717
Args:
717
718
initial_instance_count (int): The initial number of instances to run
718
- in the ``Endpoint`` created from this ``Model``.
719
+ in the ``Endpoint`` created from this ``Model``. If not using
720
+ serverless inference, then it need to be a number larger or equals
721
+ to 1 (default: None)
719
722
instance_type (str): The EC2 instance type to deploy this Model to.
720
- For example, 'ml.p2.xlarge', or 'local' for local mode.
723
+ For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
724
+ serverless inference, then it is required to deploy a model.
725
+ (default: None)
721
726
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
722
727
serializer object, used to encode data for an inference endpoint
723
728
(default: None). If ``serializer`` is not None, then
@@ -746,7 +751,14 @@ def deploy(
746
751
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
747
752
configuration related to Endpoint data capture for use with
748
753
Amazon SageMaker Model Monitoring. Default: None.
749
-
754
+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
755
+ Specifies configuration related to serverless endpoint. Use this configuration
756
+ when trying to create serverless endpoint and make serverless inference. If
757
+ empty config object passed through, we will use default config to deploy
758
+ serverless endpoint (default: None)
759
+ Raises:
760
+ ValueError: If no role is specified or if serverless inference config is not
761
+ specified and instance type and instance count are also not specified
750
762
Returns:
751
763
callable[string, sagemaker.session.Session] or None: Invocation of
752
764
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
@@ -757,28 +769,43 @@ def deploy(
757
769
758
770
if self .role is None :
759
771
raise ValueError ("Role can not be null for deploying a model" )
772
+ is_serverless = bool (serverless_inference_config )
773
+ if not is_serverless and not (instance_type and initial_instance_count ):
774
+ raise ValueError (
775
+ "Must specify instance type and instance count unless using serverless inference"
776
+ )
760
777
761
- if instance_type .startswith ("ml.inf" ) and not self ._is_compiled_model :
778
+ if instance_type and instance_type .startswith ("ml.inf" ) and not self ._is_compiled_model :
762
779
LOGGER .warning (
763
780
"Your model is not compiled. Please compile your model before using Inferentia."
764
781
)
765
782
766
- compiled_model_suffix = "-" . join ( instance_type . split ( "." )[: - 1 ])
767
- if self . _is_compiled_model :
783
+ if self . _is_compiled_model and not is_serverless :
784
+ compiled_model_suffix = "-" . join ( instance_type . split ( "." )[: - 1 ])
768
785
self ._ensure_base_name_if_needed (self .image_uri )
769
786
if self ._base_name is not None :
770
787
self ._base_name = "-" .join ((self ._base_name , compiled_model_suffix ))
771
788
772
789
self ._create_sagemaker_model (instance_type , accelerator_type , tags )
790
+
791
+ serverless_inference_config_dict = (
792
+ serverless_inference_config ._to_request_dict () if is_serverless else None
793
+ )
773
794
production_variant = sagemaker .production_variant (
774
- self .name , instance_type , initial_instance_count , accelerator_type = accelerator_type
795
+ self .name ,
796
+ instance_type ,
797
+ initial_instance_count ,
798
+ accelerator_type = accelerator_type ,
799
+ serverless_inference_config = serverless_inference_config_dict ,
775
800
)
776
801
if endpoint_name :
777
802
self .endpoint_name = endpoint_name
778
803
else :
779
804
base_endpoint_name = self ._base_name or utils .base_from_name (self .name )
780
- if self ._is_compiled_model and not base_endpoint_name .endswith (compiled_model_suffix ):
781
- base_endpoint_name = "-" .join ((base_endpoint_name , compiled_model_suffix ))
805
+ if self ._is_compiled_model and not is_serverless :
806
+ compiled_model_suffix = "-" .join (instance_type .split ("." )[:- 1 ])
807
+ if not base_endpoint_name .endswith (compiled_model_suffix ):
808
+ base_endpoint_name = "-" .join ((base_endpoint_name , compiled_model_suffix ))
782
809
self .endpoint_name = utils .name_from_base (base_endpoint_name )
783
810
784
811
data_capture_config_dict = None
0 commit comments