diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 14724e207d..0d5e8605e2 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -34,14 +34,19 @@ JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, ) +from sagemaker.model_metrics import ModelMetrics +from sagemaker.metadata_properties import MetadataProperties +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import ( JumpStartModelDeployKwargs, JumpStartModelInitKwargs, + JumpStartModelRegisterKwargs, ) from sagemaker.jumpstart.utils import ( update_dict_if_key_not_present, resolve_model_sagemaker_config_field, + verify_model_region_and_return_specs, ) from sagemaker.model_monitor.data_capture_config import DataCaptureConfig @@ -507,6 +512,87 @@ def get_deploy_kwargs( return deploy_kwargs +def get_register_kwargs( + model_id: str, + model_version: Optional[str] = None, + region: Optional[str] = None, + tolerate_deprecated_model: Optional[bool] = None, + tolerate_vulnerable_model: Optional[bool] = None, + sagemaker_session: Optional[Any] = None, + supported_content_types: List[str] = None, + response_types: List[str] = None, + inference_instances: Optional[List[str]] = None, + transform_instances: Optional[List[str]] = None, + model_package_group_name: Optional[str] = None, + image_uri: Optional[str] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + approval_status: Optional[str] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, str]] = None, + validation_specification: Optional[str] = None, + domain: Optional[str] = None, + task: Optional[str] = None, + sample_payload_url: Optional[str] = None, + framework: Optional[str] = None, + 
framework_version: Optional[str] = None, + nearest_model_name: Optional[str] = None, + data_input_configuration: Optional[str] = None, + skip_model_validation: Optional[str] = None, +) -> JumpStartModelRegisterKwargs: + """Returns kwargs required to call `register` on `sagemaker.model.Model` object.""" + + register_kwargs = JumpStartModelRegisterKwargs( + model_id=model_id, + model_version=model_version, + region=region, + tolerate_deprecated_model=tolerate_deprecated_model, + tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, + content_types=supported_content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + model_package_group_name=model_package_group_name, + image_uri=image_uri, + model_metrics=model_metrics, + metadata_properties=metadata_properties, + approval_status=approval_status, + description=description, + drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, + validation_specification=validation_specification, + domain=domain, + task=task, + sample_payload_url=sample_payload_url, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, + ) + + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + region=region, + scope=JumpStartScriptScope.INFERENCE, + sagemaker_session=sagemaker_session, + tolerate_deprecated_model=tolerate_deprecated_model, + tolerate_vulnerable_model=tolerate_vulnerable_model, + ) + + register_kwargs.content_types = ( + register_kwargs.content_types or model_specs.predictor_specs.supported_content_types + ) + register_kwargs.response_types = ( + register_kwargs.response_types or model_specs.predictor_specs.supported_accept_types + ) + + return register_kwargs + + + def 
get_init_kwargs( model_id: str, model_from_estimator: bool = False, diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 95a4bb3b99..8c246da20b 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -13,7 +13,6 @@ """This module stores JumpStart implementation of Model class.""" from __future__ import absolute_import -import re from typing import Dict, List, Optional, Union from sagemaker import payloads @@ -28,16 +27,23 @@ get_default_predictor, get_deploy_kwargs, get_init_kwargs, + get_register_kwargs, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import is_valid_model_id from sagemaker.utils import stringify_object -from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model +from sagemaker.model import ( + Model, + ModelPackage, +) from sagemaker.model_monitor.data_capture_config import DataCaptureConfig from sagemaker.predictor import PredictorBase from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig from sagemaker.session import Session from sagemaker.workflow.entities import PipelineVariable +from sagemaker.model_metrics import ModelMetrics +from sagemaker.metadata_properties import MetadataProperties +from sagemaker.drift_check_baselines import DriftCheckBaselines class JumpStartModel(Model): @@ -309,11 +315,12 @@ def _is_valid_model_id_hook(): self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region - self.model_package_arn = model_init_kwargs.model_package_arn self.sagemaker_session = model_init_kwargs.sagemaker_session super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) + self.model_package_arn = model_init_kwargs.model_package_arn + def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: """Returns all example payloads 
associated with the model. @@ -390,30 +397,29 @@ def _create_sagemaker_model( # inference endpoint. if self.model_package_arn and not self._model_data_is_set: # When a ModelPackageArn is provided we just create the Model - match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn) - if match: - model_package_name = match.group(3) - else: - # model_package_arn can be just the name if your account owns the Model Package - model_package_name = self.model_package_arn - container_def = {"ModelPackageName": self.model_package_arn} - - if self.env != {}: - container_def["Environment"] = self.env - - if self.name is None: - self._base_name = model_package_name - - self._set_model_name_if_needed() - - self.sagemaker_session.create_model( - self.name, - self.role, - container_def, + model_package = ModelPackage( + role=self.role, + model_data=self.model_data, + model_package_arn=self.model_package_arn, + sagemaker_session=self.sagemaker_session, + predictor_cls=self.predictor_cls, vpc_config=self.vpc_config, - enable_network_isolation=self.enable_network_isolation(), + ) + if self.name is not None: + model_package.name = self.name + if self.env is not None: + model_package.env = self.env + model_package._create_sagemaker_model( + instance_type=instance_type, + accelerator_type=accelerator_type, tags=tags, + serverless_inference_config=serverless_inference_config, + **kwargs, ) + if self._base_name is None and model_package._base_name is not None: + self._base_name = model_package._base_name + if self.name is None and model_package.name is not None: + self.name = model_package.name else: super(JumpStartModel, self)._create_sagemaker_model( instance_type=instance_type, @@ -565,6 +571,124 @@ def deploy( # If a predictor class was passed, do not mutate predictor return predictor + def register( + self, + content_types: List[Union[str, PipelineVariable]] = None, + response_types: List[Union[str, PipelineVariable]] = None, + inference_instances: Optional[List[Union[str, 
PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + validation_specification: Optional[Union[str, PipelineVariable]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, + ): + """Creates a model package for creating SageMaker models or listing on Marketplace. + + Args: + content_types (list[str] or list[PipelineVariable]): The supported MIME types + for the input data. + response_types (list[str] or list[PipelineVariable]): The supported MIME types + for the output data. + inference_instances (list[str] or list[PipelineVariable]): A list of the instance + types that are used to generate inferences in real-time (default: None). + transform_instances (list[str] or list[PipelineVariable]): A list of the instance types + on which a transformation job can be run or on which an endpoint can be deployed + (default: None). 
+ model_package_group_name (str or PipelineVariable): Model Package Group name, + exclusive to `model_package_name`, using `model_package_group_name` makes the + Model Package versioned. Defaults to ``None``. + image_uri (str or PipelineVariable): Inference image URI for the container. Model class' + self.image will be used if it is None. Defaults to ``None``. + model_metrics (ModelMetrics): ModelMetrics object. Defaults to ``None``. + metadata_properties (MetadataProperties): MetadataProperties object. + Defaults to ``None``. + approval_status (str or PipelineVariable): Model Approval Status, values can be + "Approved", "Rejected", or "PendingManualApproval". Defaults to + ``PendingManualApproval``. + description (str): Model Package description. Defaults to ``None``. + drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str] or dict[str, PipelineVariable]): + A dictionary of key-value paired metadata properties (default: None). + domain (str or PipelineVariable): Domain values can be "COMPUTER_VISION", + "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str or PipelineVariable): The S3 path where the sample payload + is stored (default: None). + task (str or PipelineVariable): Task values which are supported by Inference Recommender + are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", + "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str or PipelineVariable): Machine learning framework of the model package + container image (default: None). + framework_version (str or PipelineVariable): Framework version of the Model Package + Container Image (default: None). + nearest_model_name (str or PipelineVariable): Name of a pre-trained machine learning + model benchmarked by Amazon SageMaker Inference Recommender (default: None). 
+ data_input_configuration (str or PipelineVariable): Input object for the model + (default: None). + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model + validation. Values can be "All" or "None" (default: None). + + Returns: + A `sagemaker.model.ModelPackage` instance. + """ + + register_kwargs = get_register_kwargs( + model_id=self.model_id, + model_version=self.model_version, + region=self.region, + tolerate_deprecated_model=self.tolerate_deprecated_model, + tolerate_vulnerable_model=self.tolerate_vulnerable_model, + sagemaker_session=self.sagemaker_session, + supported_content_types=content_types, + response_types=response_types, + inference_instances=inference_instances, + transform_instances=transform_instances, + model_package_group_name=model_package_group_name, + image_uri=image_uri, + model_metrics=model_metrics, + metadata_properties=metadata_properties, + approval_status=approval_status, + description=description, + drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, + validation_specification=validation_specification, + domain=domain, + task=task, + sample_payload_url=sample_payload_url, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + skip_model_validation=skip_model_validation, + ) + + model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) + + def register_deploy_wrapper(*args, **kwargs): + if self.model_package_arn is not None: + return self.deploy(*args, **kwargs) + + self.model_package_arn = model_package.model_package_arn + predictor = self.deploy(*args, **kwargs) + self.model_package_arn = None + return predictor + + model_package.deploy = register_deploy_wrapper + + return model_package + def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" return stringify_object(self) diff --git 
a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 4f5a8489f0..666ca23e87 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -16,6 +16,9 @@ from enum import Enum from typing import Any, Dict, List, Optional, Set, Union from sagemaker.utils import get_instance_type_family +from sagemaker.model_metrics import ModelMetrics +from sagemaker.metadata_properties import MetadataProperties +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.session import Session from sagemaker.workflow.entities import PipelineVariable @@ -1486,3 +1489,107 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model self.use_compiled_model = use_compiled_model + + +class JumpStartModelRegisterKwargs(JumpStartKwargs): + """Data class for the inputs to `JumpStartModel.register` method.""" + + __slots__ = [ + "tolerate_vulnerable_model", + "tolerate_deprecated_model", + "region", + "model_id", + "model_version", + "sagemaker_session", + "content_types", + "response_types", + "inference_instances", + "transform_instances", + "model_package_group_name", + "image_uri", + "model_metrics", + "metadata_properties", + "approval_status", + "description", + "drift_check_baselines", + "customer_metadata_properties", + "validation_specification", + "domain", + "task", + "sample_payload_url", + "framework", + "framework_version", + "nearest_model_name", + "data_input_configuration", + "skip_model_validation", + ] + + SERIALIZATION_EXCLUSION_SET = { + "tolerate_vulnerable_model", + "tolerate_deprecated_model", + "region", + "model_id", + "model_version", + "sagemaker_session", + } + + def __init__( + self, + model_id: str, + model_version: Optional[str] = None, + region: Optional[str] = None, + tolerate_deprecated_model: Optional[bool] = None, + tolerate_vulnerable_model: Optional[bool] = None, + sagemaker_session: Optional[Any] = None, + 
content_types: List[str] = None, + response_types: List[str] = None, + inference_instances: Optional[List[str]] = None, + transform_instances: Optional[List[str]] = None, + model_package_group_name: Optional[str] = None, + image_uri: Optional[str] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + approval_status: Optional[str] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, str]] = None, + validation_specification: Optional[str] = None, + domain: Optional[str] = None, + task: Optional[str] = None, + sample_payload_url: Optional[str] = None, + framework: Optional[str] = None, + framework_version: Optional[str] = None, + nearest_model_name: Optional[str] = None, + data_input_configuration: Optional[str] = None, + skip_model_validation: Optional[str] = None, + ) -> None: + """Instantiates JumpStartModelRegisterKwargs object.""" + + self.model_id = model_id + self.model_version = model_version + self.region = region + self.image_uri = image_uri + self.sagemaker_session = sagemaker_session + self.tolerate_deprecated_model = tolerate_deprecated_model + self.tolerate_vulnerable_model = tolerate_vulnerable_model + self.content_types = content_types + self.response_types = response_types + self.inference_instances = inference_instances + self.transform_instances = transform_instances + self.model_package_group_name = model_package_group_name + self.image_uri = image_uri + self.model_metrics = model_metrics + self.metadata_properties = metadata_properties + self.approval_status = approval_status + self.description = description + self.drift_check_baselines = drift_check_baselines + self.customer_metadata_properties = customer_metadata_properties + self.validation_specification = validation_specification + self.domain = domain + self.task = task + self.sample_payload_url = sample_payload_url + 
self.framework = framework + self.framework_version = framework_version + self.nearest_model_name = nearest_model_name + self.data_input_configuration = data_input_configuration + self.skip_model_validation = skip_model_validation diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index dfaa802779..528d022946 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -315,6 +315,8 @@ def __init__( self.name = name self._base_name = None self.sagemaker_session = sagemaker_session + self.algorithm_arn = None + self.model_package_arn = None # Workaround for config injection if sagemaker_session is None, since in # that case sagemaker_session will not be initialized until @@ -534,6 +536,7 @@ def register( model_data=self.model_data, model_package_arn=model_package.get("ModelPackageArn"), sagemaker_session=self.sagemaker_session, + predictor_cls=self.predictor_cls, ) @runnable_by_pipeline @@ -792,61 +795,87 @@ def _create_sagemaker_model( Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. """ - container_def = self.prepare_container_def( - instance_type, - accelerator_type=accelerator_type, - serverless_inference_config=serverless_inference_config, - ) - if not isinstance(self.sagemaker_session, PipelineSession): - # _base_name, model_name are not needed under PipelineSession. 
- # the model_data may be Pipeline variable - # which may break the _base_name generation - model_uri = None - if isinstance(self.model_data, (str, PipelineVariable)): - model_uri = self.model_data - elif isinstance(self.model_data, dict): - model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None) - - self._ensure_base_name_if_needed( - image_uri=container_def["Image"], - script_uri=self.source_dir, - model_uri=model_uri, + if self.model_package_arn is not None or self.algorithm_arn is not None: + model_package = ModelPackage( + role=self.role, + model_data=self.model_data, + model_package_arn=self.model_package_arn, + algorithm_arn=self.algorithm_arn, + sagemaker_session=self.sagemaker_session, + predictor_cls=self.predictor_cls, + vpc_config=self.vpc_config, + ) + if self.name is not None: + model_package.name = self.name + if self.env is not None: + model_package.env = self.env + model_package._create_sagemaker_model( + instance_type=instance_type, + accelerator_type=accelerator_type, + tags=tags, + serverless_inference_config=serverless_inference_config, + ) + if self._base_name is None and model_package._base_name is not None: + self._base_name = model_package._base_name + if self.name is None and model_package.name is not None: + self.name = model_package.name + else: + container_def = self.prepare_container_def( + instance_type, + accelerator_type=accelerator_type, + serverless_inference_config=serverless_inference_config, ) - self._set_model_name_if_needed() - self._init_sagemaker_session_if_does_not_exist(instance_type) - # Depending on the instance type, a local session (or) a session is initialized. 
- self.role = resolve_value_from_config( - self.role, - MODEL_EXECUTION_ROLE_ARN_PATH, - sagemaker_session=self.sagemaker_session, - ) - self.vpc_config = resolve_value_from_config( - self.vpc_config, - MODEL_VPC_CONFIG_PATH, - sagemaker_session=self.sagemaker_session, - ) - self._enable_network_isolation = resolve_value_from_config( - self._enable_network_isolation, - MODEL_ENABLE_NETWORK_ISOLATION_PATH, - sagemaker_session=self.sagemaker_session, - ) - self.env = resolve_nested_dict_value_from_config( - self.env, - ["Environment"], - MODEL_CONTAINERS_PATH, - sagemaker_session=self.sagemaker_session, - ) - create_model_args = dict( - name=self.name, - role=self.role, - container_defs=container_def, - vpc_config=self.vpc_config, - enable_network_isolation=self._enable_network_isolation, - tags=tags, - ) - self.sagemaker_session.create_model(**create_model_args) + if not isinstance(self.sagemaker_session, PipelineSession): + # _base_name, model_name are not needed under PipelineSession. + # the model_data may be Pipeline variable + # which may break the _base_name generation + model_uri = None + if isinstance(self.model_data, (str, PipelineVariable)): + model_uri = self.model_data + elif isinstance(self.model_data, dict): + model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None) + + self._ensure_base_name_if_needed( + image_uri=container_def["Image"], + script_uri=self.source_dir, + model_uri=model_uri, + ) + self._set_model_name_if_needed() + + self._init_sagemaker_session_if_does_not_exist(instance_type) + # Depending on the instance type, a local session (or) a session is initialized. 
+ self.role = resolve_value_from_config( + self.role, + MODEL_EXECUTION_ROLE_ARN_PATH, + sagemaker_session=self.sagemaker_session, + ) + self.vpc_config = resolve_value_from_config( + self.vpc_config, + MODEL_VPC_CONFIG_PATH, + sagemaker_session=self.sagemaker_session, + ) + self._enable_network_isolation = resolve_value_from_config( + self._enable_network_isolation, + MODEL_ENABLE_NETWORK_ISOLATION_PATH, + sagemaker_session=self.sagemaker_session, + ) + self.env = resolve_nested_dict_value_from_config( + self.env, + ["Environment"], + MODEL_CONTAINERS_PATH, + sagemaker_session=self.sagemaker_session, + ) + create_model_args = dict( + name=self.name, + role=self.role, + container_defs=container_def, + vpc_config=self.vpc_config, + enable_network_isolation=self._enable_network_isolation, + tags=tags, + ) + self.sagemaker_session.create_model(**create_model_args) def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri): """Create a base name from the image URI if there is no model name provided. @@ -1897,6 +1926,7 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar kwargs: Keyword arguments coming from the caller. This class does not require any so they are ignored. """ + if self.algorithm_arn: # When ModelPackage is created using an algorithm_arn we need to first # create a ModelPackage. If we had already created one then its fine to re-use it. 
diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index bf39805897..9843b17c41 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -30,7 +30,6 @@ get_tabular_data, ) - MAX_INIT_TIME_SECONDS = 5 GATED_INFERENCE_MODEL_SUPPORTED_REGIONS = { @@ -130,3 +129,26 @@ def test_instatiating_model_not_too_slow(setup): elapsed_time = time.perf_counter() - start_time assert elapsed_time <= MAX_INIT_TIME_SECONDS + + +def test_jumpstart_model_register(setup): + model_id = "huggingface-txt2img-conflictx-complex-lineart" + + model = JumpStartModel( + model_id=model_id, + role=get_sm_session().get_caller_identity_arn(), + sagemaker_session=get_sm_session(), + ) + + model_package = model.register() + + # uses instance + predictor = model_package.deploy( + instance_type="ml.p3.2xlarge", + initial_instance_count=1, + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], + ) + + response = predictor.predict("hello world!") + + assert response is not None diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 5527929b03..bb8e0cb389 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -853,6 +853,38 @@ def test_model_artifact_variant_model( enable_network_isolation=True, ) + @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.model.Model.register") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_registry_accept_and_response_types( + self, + 
mock_model_register: mock.Mock, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_is_valid_model_id: mock.Mock, + ): + mock_model_deploy.return_value = default_predictor + + mock_is_valid_model_id.return_value = True + model_id, _ = "model_data_s3_prefix_model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge") + + model.register() + + mock_model_register.assert_called_once_with( + content_types=["application/x-text"], + response_types=["application/json;verbose", "application/json"], + ) + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError):