Skip to content

Commit 65f4cc3

Browse files
author
Ashish Gupta
committed
add more tests
1 parent 4431dd8 commit 65f4cc3

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

tests/unit/sagemaker/model/test_model.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,54 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path(
958958
sagemaker_session.endpoint_in_service_or_not.reset_mock()
959959
sagemaker_session.create_model.reset_mock()
960960

961+
@patch("sagemaker.utils.repack_model")
@patch("sagemaker.fw_utils.tar_and_upload_dir")
def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
    tar_and_upload_dir, repack_model, sagemaker_session
):
    """Check that a sharded model deploys via the inference-component path.

    The model is flagged as sharded (``_is_sharded_model = True``) and then
    deployed with ``endpoint_type=EndpointType.MODEL_BASED``.  The test asserts
    that ``create_inference_component`` is nevertheless called exactly once,
    alongside ``endpoint_in_service_or_not`` and ``create_model``.

    Args:
        tar_and_upload_dir: Mock for ``sagemaker.fw_utils.tar_and_upload_dir``.
        repack_model: Mock for ``sagemaker.utils.repack_model``.
        sagemaker_session: Session fixture whose calls are inspected.
    """
    # Stacked ``@patch`` decorators inject their mocks bottom-up: the decorator
    # closest to ``def`` supplies the first positional argument.  The parameter
    # names above follow that order.  Neither mock is inspected here; they only
    # keep the patched helpers from running for real.
    framework_model_classes_to_kwargs = {
        HuggingFaceModel: {
            "pytorch_version": "1.7.1",
            "py_version": "py36",
            "transformers_version": "4.6.1",
        },
    }

    sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False)

    source_dir = "s3://blah/blah/blah"
    for framework_model_class, kwargs in framework_model_classes_to_kwargs.items():
        test_sharded_model = framework_model_class(
            entry_point=ENTRY_POINT_INFERENCE,
            role=ROLE,
            sagemaker_session=sagemaker_session,
            model_data=source_dir,
            **kwargs,
        )
        test_sharded_model._is_sharded_model = True
        # MODEL_BASED is requested on purpose: the assertions below expect the
        # inference-component path to be taken anyway for a sharded model.
        test_sharded_model.deploy(
            instance_type="ml.m2.xlarge",
            initial_instance_count=INSTANCE_COUNT,
            endpoint_type=EndpointType.MODEL_BASED,
            resources=ResourceRequirements(
                requests={
                    "num_accelerators": 1,
                    "memory": 8192,
                    "copies": 1,
                },
                limits={},
            ),
        )

        # Verify that the inference-component-based endpoint and the inference
        # component itself were each created exactly once.
        sagemaker_session.endpoint_in_service_or_not.assert_called_once()
        sagemaker_session.create_model.assert_called_once()
        sagemaker_session.create_inference_component.assert_called_once()

        # Reset call counts so each loop iteration asserts on a clean slate.
        sagemaker_session.create_inference_component.reset_mock()
        sagemaker_session.endpoint_in_service_or_not.reset_mock()
        sagemaker_session.create_model.reset_mock()
9611009

9621010
@patch("sagemaker.utils.repack_model")
9631011
def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session):

0 commit comments

Comments
 (0)