Skip to content

Commit c5c8f3f

Browse files
bhaozknikure
authored and committed
fix: fix HuggingFace GEN2 model deployment arguments (#1404)
1 parent 363edc8 commit c5c8f3f

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

src/sagemaker/huggingface/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,9 @@ def deploy(
331331
inference_recommendation_id=inference_recommendation_id,
332332
explainer_config=explainer_config,
333333
endpoint_logging=kwargs.get("endpoint_logging", False),
334+
endpoint_type=kwargs.get("endpoint_type", None),
335+
resources=kwargs.get("resources", None),
336+
managed_instance_scaling=kwargs.get("managed_instance_scaling", None),
334337
)
335338

336339
def register(

tests/integ/test_huggingface.py

+56
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from sagemaker.utils import unique_name_from_base
2222
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2323
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
24+
from sagemaker.enums import EndpointType
25+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
2426

2527
ROLE = "SageMakerRole"
2628

@@ -172,3 +174,57 @@ def test_huggingface_inference(
172174
}
173175
output = predictor.predict(data)
174176
assert "score" in output[0]
177+
178+
179+
@pytest.mark.skip(
180+
reason="re-enable when above GEN1 endpoint hugging face inference test enabled",
181+
)
182+
def test_huggingface_inference_gen2_endpoint(
183+
sagemaker_session,
184+
gpu_pytorch_instance_type,
185+
huggingface_inference_latest_version,
186+
huggingface_inference_pytorch_latest_version,
187+
huggingface_pytorch_latest_inference_py_version,
188+
):
189+
env = {
190+
"HF_MODEL_ID": "philschmid/tiny-distilbert-classification",
191+
"HF_TASK": "text-classification",
192+
}
193+
endpoint_name = unique_name_from_base("test-hf-inference")
194+
195+
model = HuggingFaceModel(
196+
sagemaker_session=sagemaker_session,
197+
role="SageMakerRole",
198+
env=env,
199+
py_version=huggingface_pytorch_latest_inference_py_version,
200+
transformers_version=huggingface_inference_latest_version,
201+
pytorch_version=huggingface_inference_pytorch_latest_version,
202+
)
203+
predictor = model.deploy(
204+
instance_type=gpu_pytorch_instance_type,
205+
initial_instance_count=1,
206+
endpoint_name=endpoint_name,
207+
endpoint_type=EndpointType.GEN2,
208+
resources=ResourceRequirements(
209+
requests={
210+
"num_accelerators": 1, # NumberOfCpuCoresRequired
211+
"memory": 8192, # MinMemoryRequiredInMb (required)
212+
"copies": 1,
213+
},
214+
limits={},
215+
),
216+
)
217+
218+
data = {
219+
"inputs": "Camera - You are awarded a SiPix Digital Camera!"
220+
"call 09061221066 fromm landline. Delivery within 28 days."
221+
}
222+
223+
output = predictor.predict(data)
224+
assert "score" in output[0]
225+
226+
# delete predictor
227+
predictor.delete_predictor(wait=True)
228+
229+
# delete endpoint
230+
predictor.delete_endpoint()

tests/unit/sagemaker/model/test_model.py

+61
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from sagemaker.sklearn.model import SKLearnModel
3333
from sagemaker.tensorflow.model import TensorFlowModel
3434
from sagemaker.xgboost.model import XGBoostModel
35+
from sagemaker.enums import EndpointType
36+
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
3537
from sagemaker.workflow.properties import Properties
3638
from tests.unit import (
3739
_test_default_bucket_and_prefix_combinations,
@@ -854,6 +856,65 @@ def test_script_mode_model_uses_jumpstart_base_name(repack_model, sagemaker_sess
854856
].startswith(JUMPSTART_RESOURCE_BASE_NAME)
855857

856858

859+
@patch("sagemaker.utils.repack_model")
860+
@patch("sagemaker.fw_utils.tar_and_upload_dir")
861+
def test_all_framework_models_generation_two_endpoint_deploy_path(
862+
repack_model, tar_and_upload_dir, sagemaker_session
863+
):
864+
framework_model_classes_to_kwargs = {
865+
PyTorchModel: {"framework_version": "1.5.0", "py_version": "py3"},
866+
TensorFlowModel: {
867+
"framework_version": "2.3",
868+
},
869+
HuggingFaceModel: {
870+
"pytorch_version": "1.7.1",
871+
"py_version": "py36",
872+
"transformers_version": "4.6.1",
873+
},
874+
MXNetModel: {"framework_version": "1.7.0", "py_version": "py3"},
875+
SKLearnModel: {
876+
"framework_version": "0.23-1",
877+
},
878+
XGBoostModel: {
879+
"framework_version": "1.3-1",
880+
},
881+
}
882+
883+
sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False)
884+
885+
source_dir = "s3://blah/blah/blah"
886+
for framework_model_class, kwargs in framework_model_classes_to_kwargs.items():
887+
framework_model_class(
888+
entry_point=ENTRY_POINT_INFERENCE,
889+
role=ROLE,
890+
sagemaker_session=sagemaker_session,
891+
model_data=source_dir,
892+
**kwargs,
893+
).deploy(
894+
instance_type="ml.m2.xlarge",
895+
initial_instance_count=INSTANCE_COUNT,
896+
endpoint_type=EndpointType.GEN2,
897+
resources=ResourceRequirements(
898+
requests={
899+
"num_accelerators": 1,
900+
"memory": 8192,
901+
"copies": 1,
902+
},
903+
limits={},
904+
),
905+
)
906+
907+
# Verified Generation2 endpoint and inference component creation
908+
# path
909+
sagemaker_session.endpoint_in_service_or_not.assert_called_once()
910+
sagemaker_session.create_model.assert_called_once()
911+
sagemaker_session.create_inference_component.assert_called_once()
912+
913+
sagemaker_session.create_inference_component.reset_mock()
914+
sagemaker_session.endpoint_in_service_or_not.reset_mock()
915+
sagemaker_session.create_model.reset_mock()
916+
917+
857918
@patch("sagemaker.utils.repack_model")
858919
def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session):
859920

0 commit comments

Comments
 (0)