From 427f60cab32b540d793c8416f8aef0af8d6bae05 Mon Sep 17 00:00:00 2001 From: Clayton Parnell <42805768+claytonparnell@users.noreply.github.com> Date: Tue, 7 Mar 2023 17:22:03 -0500 Subject: [PATCH] Revert "feature: Decouple model.right_size() from model registry (#3688)" This reverts commit 1ad9d9cd57db3ebfb4313aafca0df0dc135d0597. --- .../inference_recommender_mixin.py | 38 +--- src/sagemaker/session.py | 18 +- tests/integ/test_inference_recommender.py | 181 ------------------ .../test_inference_recommender_mixin.py | 161 ++-------------- tests/unit/test_session.py | 127 ------------ 5 files changed, 23 insertions(+), 502 deletions(-) diff --git a/src/sagemaker/inference_recommender/inference_recommender_mixin.py b/src/sagemaker/inference_recommender/inference_recommender_mixin.py index 9ea05dd4a9..de40162c0f 100644 --- a/src/sagemaker/inference_recommender/inference_recommender_mixin.py +++ b/src/sagemaker/inference_recommender/inference_recommender_mixin.py @@ -38,7 +38,7 @@ class Phase: """ def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int): - """Initialize a `Phase`""" + """Initialze a `Phase`""" self.to_json = { "DurationInSeconds": duration_in_seconds, "InitialNumberOfUsers": initial_number_of_users, @@ -53,7 +53,7 @@ class ModelLatencyThreshold: """ def __init__(self, percentile: str, value_in_milliseconds: int): - """Initialize a `ModelLatencyThreshold`""" + """Initialze a `ModelLatencyThreshold`""" self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds} @@ -79,12 +79,6 @@ def right_size( ): """Recommends an instance type for a SageMaker or BYOC model. - Create a SageMaker ``Model`` or use a registered ``ModelPackage``, - to start an Inference Recommender job. - - The name of the created model is accessible in the ``name`` field of - this ``Model`` after right_size returns. - Args: sample_payload_url (str): The S3 path where the sample payload is stored. supported_content_types: (list[str]): The supported MIME types for the input data. @@ -125,6 +119,8 @@ def right_size( sagemaker.model.Model: A SageMaker ``Model`` object. See :func:`~sagemaker.model.Model` for full details. """ + if not isinstance(self, sagemaker.model.ModelPackage): + raise ValueError("right_size() is currently only supported with a registered model") if not framework and self._framework(): framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(self._framework(), framework) @@ -153,36 +149,12 @@ def right_size( self._init_sagemaker_session_if_does_not_exist() - if isinstance(self, sagemaker.model.Model) and not isinstance( - self, sagemaker.model.ModelPackage - ): - primary_container_def = self.prepare_container_def() - if not self.name: - self._ensure_base_name_if_needed( - image_uri=primary_container_def["Image"], - script_uri=self.source_dir, - model_uri=self.model_data, - ) - self._set_model_name_if_needed() - - create_model_args = dict( - name=self.name, - role=self.role, - container_defs=None, - primary_container=primary_container_def, - vpc_config=self.vpc_config, - enable_network_isolation=self.enable_network_isolation(), - ) - LOGGER.warning("Attempting to create new model with name %s", self.name) - self.sagemaker_session.create_model(**create_model_args) - ret_name = self.sagemaker_session.create_inference_recommendations_job( role=self.role, job_name=job_name, job_type=job_type, job_duration_in_seconds=job_duration_in_seconds, - model_name=self.name, - model_package_version_arn=getattr(self, "model_package_arn", None), + model_package_version_arn=self.model_package_arn, framework=framework, framework_version=framework_version, sample_payload_url=sample_payload_url, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index d7e587de7c..2f5191bc30 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4820,7 +4820,6 @@ def _create_inference_recommendations_job_request( framework: str, sample_payload_url: str, supported_content_types: List[str], - model_name: str = None, model_package_version_arn: str = None, job_duration_in_seconds: int = None, job_type: str = "Default", @@ -4844,7 +4843,6 @@ def _create_inference_recommendations_job_request( framework (str): The machine learning framework of the Image URI. sample_payload_url (str): The S3 path where the sample payload is stored. supported_content_types (List[str]): The supported MIME types for the input data. - model_name (str): Name of the Amazon SageMaker ``Model`` to be used. model_package_version_arn (str): The Amazon Resource Name (ARN) of a versioned model package. job_duration_in_seconds (int): The maximum job duration that a job @@ -4892,15 +4890,10 @@ def _create_inference_recommendations_job_request( "RoleArn": role, "InputConfig": { "ContainerConfig": containerConfig, + "ModelPackageVersionArn": model_package_version_arn, }, } - request.get("InputConfig").update( - {"ModelPackageVersionArn": model_package_version_arn} - if model_package_version_arn - else {"ModelName": model_name} - ) - if job_description: request["JobDescription"] = job_description if job_duration_in_seconds: @@ -4925,7 +4918,6 @@ def create_inference_recommendations_job( supported_content_types: List[str], job_name: str = None, job_type: str = "Default", - model_name: str = None, model_package_version_arn: str = None, job_duration_in_seconds: int = None, nearest_model_name: str = None, @@ -4946,7 +4938,6 @@ def create_inference_recommendations_job( You must grant sufficient permissions to this role. sample_payload_url (str): The S3 path where the sample payload is stored. supported_content_types (List[str]): The supported MIME types for the input data. - model_name (str): Name of the Amazon SageMaker ``Model`` to be used. model_package_version_arn (str): The Amazon Resource Name (ARN) of a versioned model package. job_name (str): The name of the job being run. @@ -4973,12 +4964,6 @@ def create_inference_recommendations_job( str: The name of the job created. In the form of `SMPYTHONSDK-` """ - if model_name is None and model_package_version_arn is None: - raise ValueError("Please provide either model_name or model_package_version_arn.") - - if model_name is not None and model_package_version_arn is not None: - raise ValueError("Please provide either model_name or model_package_version_arn.") - if not job_name: unique_tail = uuid.uuid4() job_name = "SMPYTHONSDK-" + str(unique_tail) @@ -4987,7 +4972,6 @@ def create_inference_recommendations_job( create_inference_recommendations_job_request = ( self._create_inference_recommendations_job_request( role=role, - model_name=model_name, model_package_version_arn=model_package_version_arn, job_name=job_name, job_type=job_type, diff --git a/tests/integ/test_inference_recommender.py b/tests/integ/test_inference_recommender.py index 77594ca94c..d363e5a00f 100644 --- a/tests/integ/test_inference_recommender.py +++ b/tests/integ/test_inference_recommender.py @@ -16,7 +16,6 @@ import pytest -from sagemaker.model import Model from sagemaker.sklearn.model import SKLearnModel, SKLearnPredictor from sagemaker.utils import unique_name_from_base from tests.integ import DATA_DIR @@ -155,120 +154,6 @@ def advanced_right_sized_model(sagemaker_session, cpu_instance_type): ) -@pytest.fixture(scope="module") -def default_right_sized_unregistered_model(sagemaker_session, cpu_instance_type): - with timeout(minutes=45): - try: - ir_job_name = unique_name_from_base("test-ir-right-size-job-name") - model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL) - payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD) - - iam_client = sagemaker_session.boto_session.client("iam") - role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] - - sklearn_model = SKLearnModel( - model_data=model_data, - role=role_arn, - entry_point=IR_SKLEARN_ENTRY_POINT, - framework_version=IR_SKLEARN_FRAMEWORK_VERSION, - ) - - return ( - sklearn_model.right_size( - job_name=ir_job_name, - sample_payload_url=payload_data, - supported_content_types=IR_SKLEARN_CONTENT_TYPE, - supported_instance_types=[cpu_instance_type], - framework=IR_SKLEARN_FRAMEWORK, - log_level="Quiet", - ), - ir_job_name, - ) - except Exception: - sagemaker_session.delete_model(ModelName=sklearn_model.name) - - -@pytest.fixture(scope="module") -def advanced_right_sized_unregistered_model(sagemaker_session, cpu_instance_type): - with timeout(minutes=45): - try: - model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL) - payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD) - - iam_client = sagemaker_session.boto_session.client("iam") - role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] - - sklearn_model = SKLearnModel( - model_data=model_data, - role=role_arn, - entry_point=IR_SKLEARN_ENTRY_POINT, - framework_version=IR_SKLEARN_FRAMEWORK_VERSION, - ) - - hyperparameter_ranges = [ - { - "instance_types": CategoricalParameter([cpu_instance_type]), - "TEST_PARAM": CategoricalParameter( - ["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"] - ), - } - ] - - phases = [ - Phase(duration_in_seconds=300, initial_number_of_users=2, spawn_rate=2), - Phase(duration_in_seconds=300, initial_number_of_users=14, spawn_rate=2), - ] - - model_latency_thresholds = [ - ModelLatencyThreshold(percentile="P95", value_in_milliseconds=100) - ] - - return sklearn_model.right_size( - sample_payload_url=payload_data, - supported_content_types=IR_SKLEARN_CONTENT_TYPE, - framework=IR_SKLEARN_FRAMEWORK, - job_duration_in_seconds=3600, - hyperparameter_ranges=hyperparameter_ranges, - phases=phases, - model_latency_thresholds=model_latency_thresholds, - max_invocations=100, - max_tests=5, - max_parallel_tests=5, - log_level="Quiet", - ) - - except Exception: - sagemaker_session.delete_model(ModelName=sklearn_model.name) - - -@pytest.fixture(scope="module") -def default_right_sized_unregistered_base_model(sagemaker_session, cpu_instance_type): - with timeout(minutes=45): - try: - ir_job_name = unique_name_from_base("test-ir-right-size-job-name") - model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL) - payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD) - - iam_client = sagemaker_session.boto_session.client("iam") - role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] - - model = Model(model_data=model_data, role=role_arn, entry_point=IR_SKLEARN_ENTRY_POINT) - - return ( - model.right_size( - job_name=ir_job_name, - sample_payload_url=payload_data, - supported_content_types=IR_SKLEARN_CONTENT_TYPE, - supported_instance_types=[cpu_instance_type], - framework=IR_SKLEARN_FRAMEWORK, - log_level="Quiet", - ), - ir_job_name, - ) - except Exception: - sagemaker_session.delete_model(ModelName=model.name) - - @pytest.mark.slow_test def test_default_right_size_and_deploy_registered_model_sklearn( default_right_sized_model, sagemaker_session @@ -291,72 +176,6 @@ def test_default_right_size_and_deploy_registered_model_sklearn( predictor.delete_endpoint() -@pytest.mark.slow_test -def test_default_right_size_and_deploy_unregistered_model_sklearn( - default_right_sized_unregistered_model, sagemaker_session -): - endpoint_name = unique_name_from_base("test-ir-right-size-default-unregistered-sklearn") - - right_size_model, ir_job_name = default_right_sized_unregistered_model - with timeout(minutes=45): - try: - right_size_model.predictor_cls = SKLearnPredictor - predictor = right_size_model.deploy(endpoint_name=endpoint_name) - - payload = pd.read_csv(IR_SKLEARN_DATA, header=None) - - inference = predictor.predict(payload) - assert inference is not None - assert 26 == len(inference) - finally: - predictor.delete_model() - predictor.delete_endpoint() - - -@pytest.mark.slow_test -def test_default_right_size_and_deploy_unregistered_base_model( - default_right_sized_unregistered_base_model, sagemaker_session -): - endpoint_name = unique_name_from_base("test-ir-right-size-default-unregistered-base") - - right_size_model, ir_job_name = default_right_sized_unregistered_base_model - with timeout(minutes=45): - try: - right_size_model.predictor_cls = SKLearnPredictor - predictor = right_size_model.deploy(endpoint_name=endpoint_name) - - payload = pd.read_csv(IR_SKLEARN_DATA, header=None) - - inference = predictor.predict(payload) - assert inference is not None - assert 26 == len(inference) - finally: - predictor.delete_model() - predictor.delete_endpoint() - - -@pytest.mark.slow_test -def test_advanced_right_size_and_deploy_unregistered_model_sklearn( - advanced_right_sized_unregistered_model, sagemaker_session -): - endpoint_name = unique_name_from_base("test-ir-right-size-advanced-sklearn") - - right_size_model = advanced_right_sized_unregistered_model - with timeout(minutes=45): - try: - right_size_model.predictor_cls = SKLearnPredictor - predictor = right_size_model.deploy(endpoint_name=endpoint_name) - - payload = pd.read_csv(IR_SKLEARN_DATA, header=None) - - inference = predictor.predict(payload) - assert inference is not None - assert 26 == len(inference) - finally: - predictor.delete_model() - predictor.delete_endpoint() - - @pytest.mark.slow_test def test_advanced_right_size_and_deploy_registered_model_sklearn( advanced_right_sized_model, sagemaker_session diff --git a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py index 50cca2003b..a8aa219dd0 100644 --- a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py +++ b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py @@ -175,147 +175,6 @@ def default_right_sized_model(model_package): ) -@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid")) -def test_right_size_default_with_model_name_successful(sagemaker_session, model): - inference_recommender_model = model.right_size( - sample_payload_url=IR_SAMPLE_PAYLOAD_URL, - supported_content_types=IR_SUPPORTED_CONTENT_TYPES, - supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], - job_name=IR_JOB_NAME, - framework=IR_SAMPLE_FRAMEWORK, - ) - - assert sagemaker_session.create_model.called_with( - name=ANY, - role=IR_ROLE_ARN, - container_defs=None, - primary_container={}, - vpc_config=None, - enable_network_isolation=False, - ) - - # assert that the create api has been called with default parameters with model name - assert sagemaker_session.create_inference_recommendations_job.called_with( - role=IR_ROLE_ARN, - job_name=IR_JOB_NAME, - job_type="Default", - job_duration_in_seconds=None, - model_name=ANY, - model_package_version_arn=None, - framework=IR_SAMPLE_FRAMEWORK, - framework_version=None, - sample_payload_url=IR_SAMPLE_PAYLOAD_URL, - supported_content_types=IR_SUPPORTED_CONTENT_TYPES, - supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], - endpoint_configurations=None, - traffic_pattern=None, - stopping_conditions=None, - resource_limit=None, - ) - - assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME) - - # confirm that the IR instance attributes have been set - assert ( - inference_recommender_model.inference_recommender_job_results - == IR_SAMPLE_INFERENCE_RESPONSE - ) - assert ( - inference_recommender_model.inference_recommendations - == IR_SAMPLE_INFERENCE_RESPONSE["InferenceRecommendations"] - ) - - # confirm that the returned object of right_size is itself - assert inference_recommender_model == model - - -@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid")) -def test_right_size_advanced_list_instances_model_name_successful(sagemaker_session, model): - inference_recommender_model = model.right_size( - sample_payload_url=IR_SAMPLE_PAYLOAD_URL, - supported_content_types=IR_SUPPORTED_CONTENT_TYPES, - framework="SAGEMAKER-SCIKIT-LEARN", - job_duration_in_seconds=7200, - hyperparameter_ranges=IR_SAMPLE_LIST_OF_INSTANCES_HYPERPARAMETER_RANGES, - phases=IR_SAMPLE_PHASES, - traffic_type="PHASES", - max_invocations=100, - model_latency_thresholds=IR_SAMPLE_MODEL_LATENCY_THRESHOLDS, - max_tests=5, - max_parallel_tests=5, - ) - - # assert that the create api has been called with advanced parameters - assert sagemaker_session.create_inference_recommendations_job.called_with( - role=IR_ROLE_ARN, - job_name=IR_JOB_NAME, - job_type="Advanced", - job_duration_in_seconds=7200, - model_name=ANY, - model_package_version_arn=None, - framework=IR_SAMPLE_FRAMEWORK, - framework_version=None, - sample_payload_url=IR_SAMPLE_PAYLOAD_URL, - supported_content_types=IR_SUPPORTED_CONTENT_TYPES, - supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], - endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG, - traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN, - stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS, - resource_limit=IR_SAMPLE_RESOURCE_LIMIT, - ) - - assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME) - - # confirm that the IR instance attributes have been set - assert ( - inference_recommender_model.inference_recommender_job_results - == IR_SAMPLE_INFERENCE_RESPONSE - ) - assert ( - inference_recommender_model.inference_recommendations - == IR_SAMPLE_INFERENCE_RESPONSE["InferenceRecommendations"] - ) - - # confirm that the returned object of right_size is itself - assert inference_recommender_model == model - - -@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid")) -def test_right_size_advanced_single_instances_model_name_successful(sagemaker_session, model): - model.right_size( - sample_payload_url=IR_SAMPLE_PAYLOAD_URL, - supported_content_types=IR_SUPPORTED_CONTENT_TYPES, - framework="SAGEMAKER-SCIKIT-LEARN", - job_duration_in_seconds=7200, - hyperparameter_ranges=IR_SAMPLE_SINGLE_INSTANCES_HYPERPARAMETER_RANGES, - phases=IR_SAMPLE_PHASES, - traffic_type="PHASES", - max_invocations=100, - model_latency_thresholds=IR_SAMPLE_MODEL_LATENCY_THRESHOLDS, - max_tests=5, - max_parallel_tests=5, - ) - - # assert that the create api has been called with advanced parameters - assert sagemaker_session.create_inference_recommendations_job.called_with( - role=IR_ROLE_ARN, - job_name=IR_JOB_NAME, - job_type="Advanced", - job_duration_in_seconds=7200, - model_name=ANY, - model_package_version_arn=None, - framework=IR_SAMPLE_FRAMEWORK, - framework_version=None, - sample_payload_url=IR_SAMPLE_PAYLOAD_URL, - supported_content_types=IR_SUPPORTED_CONTENT_TYPES, - supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], - endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG, - traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN, - stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS, - resource_limit=IR_SAMPLE_RESOURCE_LIMIT, - ) - - def test_right_size_default_with_model_package_successful(sagemaker_session, model_package): inference_recommender_model_pkg = model_package.right_size( sample_payload_url=IR_SAMPLE_PAYLOAD_URL, @@ -331,7 +190,6 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod job_name=IR_JOB_NAME, job_type="Default", job_duration_in_seconds=None, - model_name=None, model_package_version_arn=model_package.model_package_arn, framework=IR_SAMPLE_FRAMEWORK, framework_version=None, @@ -344,7 +202,7 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod resource_limit=None, ) - assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME) + assert sagemaker_session.wait_for_inference_recomendations_job.called_with(IR_JOB_NAME) # confirm that the IR instance attributes have been set assert ( @@ -395,7 +253,7 @@ def test_right_size_advanced_list_instances_model_package_successful( resource_limit=IR_SAMPLE_RESOURCE_LIMIT, ) - assert sagemaker_session.wait_for_inference_recommendations_job.called_with(IR_JOB_NAME) + assert sagemaker_session.wait_for_inference_recomendations_job.called_with(IR_JOB_NAME) # confirm that the IR instance attributes have been set assert ( @@ -501,6 +359,21 @@ def test_right_size_invalid_hyperparameter_ranges(sagemaker_session, model_packa ) +# TODO -> removed once model registry is decoupled +def test_right_size_missing_model_package_arn(sagemaker_session, model): + with pytest.raises( + ValueError, + match="right_size\\(\\) is currently only supported with a registered model", + ): + model.right_size( + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], + job_name=IR_JOB_NAME, + framework=IR_SAMPLE_FRAMEWORK, + ) + + # TODO check our framework mapping when we add in inference_recommendation_id support diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 49cf8ad5c0..b9be3fb285 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -3194,7 +3194,6 @@ def test_batch_get_record(sagemaker_session): IR_MODEL_PACKAGE_VERSION_ARN = ( "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" ) -IR_MODEL_NAME = "MODEL_NAME" IR_NEAREST_MODEL_NAME = "xgboost" IR_SUPPORTED_INSTANCE_TYPES = ["ml.c5.xlarge", "ml.c5.2xlarge"] IR_FRAMEWORK = "XGBOOST" @@ -3245,30 +3244,6 @@ def create_inference_recommendations_job_default_happy_response(): } -def create_inference_recommendations_job_default_model_name_happy_response(): - return { - "JobName": IR_USER_JOB_NAME, - "JobType": "Default", - "RoleArn": IR_ROLE_ARN, - "InputConfig": { - "ContainerConfig": { - "Domain": "MACHINE_LEARNING", - "Task": "OTHER", - "Framework": IR_FRAMEWORK, - "PayloadConfig": { - "SamplePayloadUrl": IR_SAMPLE_PAYLOAD_URL, - "SupportedContentTypes": IR_SUPPORTED_CONTENT_TYPES, - }, - "FrameworkVersion": IR_FRAMEWORK_VERSION, - "NearestModelName": IR_NEAREST_MODEL_NAME, - "SupportedInstanceTypes": IR_SUPPORTED_INSTANCE_TYPES, - }, - "ModelName": IR_MODEL_NAME, - }, - "JobDescription": "#python-sdk-create", - } - - def create_inference_recommendations_job_advanced_happy_response(): base_advanced_job_response = create_inference_recommendations_job_default_happy_response() @@ -3283,22 +3258,6 @@ def create_inference_recommendations_job_advanced_happy_response(): return base_advanced_job_response -def create_inference_recommendations_job_advanced_model_name_happy_response(): - base_advanced_job_response = ( - create_inference_recommendations_job_default_model_name_happy_response() - ) - - base_advanced_job_response["JobName"] = IR_JOB_NAME - base_advanced_job_response["JobType"] = IR_ADVANCED_JOB - base_advanced_job_response["StoppingConditions"] = IR_STOPPING_CONDITIONS - base_advanced_job_response["InputConfig"]["JobDurationInSeconds"] = IR_JOB_DURATION_IN_SECONDS - base_advanced_job_response["InputConfig"]["EndpointConfigurations"] = IR_ENDPOINT_CONFIGURATIONS - base_advanced_job_response["InputConfig"]["TrafficPattern"] = IR_TRAFFIC_PATTERN - base_advanced_job_response["InputConfig"]["ResourceLimit"] = IR_RESOURCE_LIMIT - - return base_advanced_job_response - - def test_create_inference_recommendations_job_default_happy(sagemaker_session): job_name = sagemaker_session.create_inference_recommendations_job( role=IR_ROLE_ARN, @@ -3345,92 +3304,6 @@ def test_create_inference_recommendations_job_advanced_happy(sagemaker_session): assert IR_JOB_NAME == job_name -def test_create_inference_recommendations_job_default_model_name_happy(sagemaker_session): - job_name = sagemaker_session.create_inference_recommendations_job( - role=IR_ROLE_ARN, - sample_payload_url=IR_SAMPLE_PAYLOAD_URL, - supported_content_types=IR_SUPPORTED_CONTENT_TYPES, - model_name=IR_MODEL_NAME, - model_package_version_arn=None, - framework=IR_FRAMEWORK, - framework_version=IR_FRAMEWORK_VERSION, - nearest_model_name=IR_NEAREST_MODEL_NAME, - supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES, - job_name=IR_USER_JOB_NAME, - ) - - sagemaker_session.sagemaker_client.create_inference_recommendations_job.assert_called_with( - **create_inference_recommendations_job_default_model_name_happy_response() - ) - - assert IR_USER_JOB_NAME == job_name - - -@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid")) -def test_create_inference_recommendations_job_advanced_model_name_happy(sagemaker_session): - job_name = sagemaker_session.create_inference_recommendations_job( - role=IR_ROLE_ARN, - sample_payload_url=IR_SAMPLE_PAYLOAD_URL, - supported_content_types=IR_SUPPORTED_CONTENT_TYPES, - model_name=IR_MODEL_NAME, - model_package_version_arn=None, - framework=IR_FRAMEWORK, - framework_version=IR_FRAMEWORK_VERSION, - nearest_model_name=IR_NEAREST_MODEL_NAME, - supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES, - endpoint_configurations=IR_ENDPOINT_CONFIGURATIONS, - traffic_pattern=IR_TRAFFIC_PATTERN, - stopping_conditions=IR_STOPPING_CONDITIONS, - resource_limit=IR_RESOURCE_LIMIT, - job_type=IR_ADVANCED_JOB, - job_duration_in_seconds=IR_JOB_DURATION_IN_SECONDS, - ) - - sagemaker_session.sagemaker_client.create_inference_recommendations_job.assert_called_with( - **create_inference_recommendations_job_advanced_model_name_happy_response() - ) - - assert IR_JOB_NAME == job_name - - -def test_create_inference_recommendations_job_missing_model_name_and_pkg(sagemaker_session): - with pytest.raises( - ValueError, - match="Please provide either model_name or model_package_version_arn.", - ): - sagemaker_session.create_inference_recommendations_job( - role=IR_ROLE_ARN, - sample_payload_url=IR_SAMPLE_PAYLOAD_URL, - supported_content_types=IR_SUPPORTED_CONTENT_TYPES, - model_name=None, - model_package_version_arn=None, - framework=IR_FRAMEWORK, - framework_version=IR_FRAMEWORK_VERSION, - nearest_model_name=IR_NEAREST_MODEL_NAME, - supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES, - job_name=IR_USER_JOB_NAME, - ) - - -def test_create_inference_recommendations_job_provided_model_name_and_pkg(sagemaker_session): - with pytest.raises( - ValueError, - match="Please provide either model_name or model_package_version_arn.", - ): - sagemaker_session.create_inference_recommendations_job( - role=IR_ROLE_ARN, - sample_payload_url=IR_SAMPLE_PAYLOAD_URL, - supported_content_types=IR_SUPPORTED_CONTENT_TYPES, - model_name=IR_MODEL_NAME, - model_package_version_arn=IR_MODEL_PACKAGE_VERSION_ARN, - framework=IR_FRAMEWORK, - framework_version=IR_FRAMEWORK_VERSION, - nearest_model_name=IR_NEAREST_MODEL_NAME, - supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES, - job_name=IR_USER_JOB_NAME, - ) - - def test_create_inference_recommendations_job_propogate_validation_exception(sagemaker_session): validation_exception_message = ( "Failed to describe model due to validation failure with following error: test_error"