@@ -958,6 +958,54 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path(
958
958
sagemaker_session .endpoint_in_service_or_not .reset_mock ()
959
959
sagemaker_session .create_model .reset_mock ()
960
960
961
@patch("sagemaker.utils.repack_model")
@patch("sagemaker.fw_utils.tar_and_upload_dir")
def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
    tar_and_upload_dir, repack_model, sagemaker_session
):
    """Check that a sharded model is deployed through the inference component path.

    The model is flagged with ``_is_sharded_model = True`` and deployed with
    ``endpoint_type=EndpointType.MODEL_BASED``. The test asserts that
    ``create_inference_component`` is nevertheless called exactly once,
    alongside ``endpoint_in_service_or_not`` and ``create_model``.

    Args:
        tar_and_upload_dir: Mock injected for ``sagemaker.fw_utils.tar_and_upload_dir``.
        repack_model: Mock injected for ``sagemaker.utils.repack_model``.
        sagemaker_session: Session fixture whose calls are asserted on.
    """
    # NOTE: ``patch`` decorators are applied bottom-up, so the decorator closest
    # to the ``def`` supplies the first positional mock argument.
    framework_model_classes_to_kwargs = {
        HuggingFaceModel: {
            "pytorch_version": "1.7.1",
            "py_version": "py36",
            "transformers_version": "4.6.1",
        },
    }

    sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False)

    source_dir = "s3://blah/blah/blah"
    for framework_model_class, kwargs in framework_model_classes_to_kwargs.items():
        test_sharded_model = framework_model_class(
            entry_point=ENTRY_POINT_INFERENCE,
            role=ROLE,
            sagemaker_session=sagemaker_session,
            model_data=source_dir,
            **kwargs,
        )
        # Mark the model as sharded before deploying; the assertions below
        # depend on this flag being set.
        test_sharded_model._is_sharded_model = True
        test_sharded_model.deploy(
            instance_type="ml.m2.xlarge",
            initial_instance_count=INSTANCE_COUNT,
            # Deliberately request a model-based endpoint: the test expects the
            # inference component calls to happen anyway.
            endpoint_type=EndpointType.MODEL_BASED,
            resources=ResourceRequirements(
                requests={
                    "num_accelerators": 1,
                    "memory": 8192,
                    "copies": 1,
                },
                limits={},
            ),
        )

        # Verify the inference component based endpoint and inference component
        # creation path was taken.
        sagemaker_session.endpoint_in_service_or_not.assert_called_once()
        sagemaker_session.create_model.assert_called_once()
        sagemaker_session.create_inference_component.assert_called_once()

        # Reset call counts so each loop iteration is asserted independently.
        sagemaker_session.create_inference_component.reset_mock()
        sagemaker_session.endpoint_in_service_or_not.reset_mock()
        sagemaker_session.create_model.reset_mock()
961
1009
962
1010
@patch ("sagemaker.utils.repack_model" )
963
1011
def test_repack_code_location_with_key_prefix (repack_model , sagemaker_session ):
0 commit comments