diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ff340b58e9..36b49f1ac1 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -156,6 +156,7 @@ def __init__( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, resources: Optional[ResourceRequirements] = None, + task: Optional[Union[str, PipelineVariable]] = None, ): """Initialize an SageMaker ``Model``. @@ -319,7 +320,9 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). - + task (str or PipelineVariable): Task value used to override the HuggingFace task. + Examples are: "audio-classification", "depth-estimation", + "feature-extraction", etc. (Default: None). """ self.model_data = model_data self.image_uri = image_uri @@ -396,6 +399,7 @@ def __init__( self.content_types = None self.response_types = None self.accept_eula = None + self.task = task @runnable_by_pipeline def register( diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 0ade8096f6..d5b28e5803 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -203,6 +203,7 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers): "help": ( 'Model object with "predict" method to perform inference ' "or HuggingFace/JumpStart Model ID" + " or HuggingFace Task to override" ) }, ) diff --git a/tests/integ/sagemaker/serve/test_serve_transformers.py b/tests/integ/sagemaker/serve/test_serve_transformers.py index 735f60d0f2..958c469159 100644 --- a/tests/integ/sagemaker/serve/test_serve_transformers.py +++ b/tests/integ/sagemaker/serve/test_serve_transformers.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import pytest +from sagemaker.model import Model from sagemaker.serve.builder.schema_builder import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder, Mode @@ 
-28,6 +29,11 @@ logger = logging.getLogger(__name__) +MODEL_DATA = "s3://bucket/model.tar.gz" +MODEL_IMAGE = "mi" +ROLE = "some-role" +HF_TASK = "fill-mask" + sample_input = { "inputs": "The man worked as a [MASK].", } @@ -80,6 +86,17 @@ def model_builder_model_schema_builder(): ) +@pytest.fixture +def model_builder_model_with_task_builder(): + model = Model( + MODEL_IMAGE, MODEL_DATA, task=HF_TASK, name="bert-base-uncased", role=ROLE + ) + return ModelBuilder( + model_path=HF_DIR, + model=model, + ) + + @pytest.fixture def model_builder(request): return request.getfixturevalue(request.param) @@ -122,3 +139,42 @@ def test_pytorch_transformers_sagemaker_endpoint( assert ( False ), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test" + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="Testing Optional task", +) +@pytest.mark.parametrize("model_builder", ["model_builder_model_with_task_builder"], indirect=True) +def test_happy_path_with_task_sagemaker_endpoint( + sagemaker_session, model_builder, gpu_instance_type, input +): + logger.info("Running in SAGEMAKER_ENDPOINT mode...") + caught_ex = None + + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + model = model_builder.build( + mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session + ) + + with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT): + try: + logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...") + predictor = model.deploy(instance_type=gpu_instance_type, initial_instance_count=1) + logger.info("Endpoint successfully deployed.") + predictor.predict(input) + except Exception as e: + caught_ex = e + finally: + cleanup_model_resources( + sagemaker_session=model_builder.sagemaker_session, + model_name=model.name, + endpoint_name=model.endpoint_name, + ) + if caught_ex: + logger.exception(caught_ex) + assert ( + False + ), f"{caught_ex} 
was thrown when running pytorch transformers sagemaker endpoint test" diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 953cbe775c..ef2f0d4516 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -85,6 +85,7 @@ }, limits={}, ) +HF_TASK = "audio-classification" @pytest.fixture @@ -1027,3 +1028,57 @@ def test_deploy_with_name_and_resources(sagemaker_session): async_inference_config_dict=None, live_logging=False, ) + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_with_name_and_task(sagemaker_session): + sagemaker_session.sagemaker_config = {} + + model = Model( + MODEL_IMAGE, MODEL_DATA, task=HF_TASK, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session + ) + + endpoint_name = "testing-task-input" + predictor = model.deploy( + endpoint_name=endpoint_name, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + ) + + sagemaker_session.create_model.assert_called_with( + name=MODEL_IMAGE, + role=ROLE, + task=HF_TASK + ) + + assert isinstance(predictor, sagemaker.predictor.Predictor) + assert predictor.endpoint_name == endpoint_name + assert predictor.sagemaker_session == sagemaker_session + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_with_name_and_without_task(sagemaker_session): + sagemaker_session.sagemaker_config = {} + + model = Model( + MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, role=ROLE, sagemaker_session=sagemaker_session + ) + + endpoint_name = "testing-without-task-input" + predictor = model.deploy( + endpoint_name=endpoint_name, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + ) + + sagemaker_session.create_model.assert_called_with( + name=MODEL_IMAGE, + role=ROLE, + 
task=None, + ) + + assert isinstance(predictor, sagemaker.predictor.Predictor) + assert predictor.endpoint_name == endpoint_name + assert predictor.sagemaker_session == sagemaker_session