diff --git a/doc/api/inference/model.rst b/doc/api/inference/model.rst index d6cb0b5003..038f34b953 100644 --- a/doc/api/inference/model.rst +++ b/doc/api/inference/model.rst @@ -5,6 +5,7 @@ Model :members: :undoc-members: :show-inheritance: + :inherited-members: .. autoclass:: sagemaker.model.FrameworkModel :members: diff --git a/src/sagemaker/inference_recommender/__init__.py b/src/sagemaker/inference_recommender/__init__.py new file mode 100644 index 0000000000..c1776fa899 --- /dev/null +++ b/src/sagemaker/inference_recommender/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Classes for using Inference Recommender with Amazon SageMaker.""" +from __future__ import absolute_import diff --git a/src/sagemaker/inference_recommender/inference_recommender_mixin.py b/src/sagemaker/inference_recommender/inference_recommender_mixin.py new file mode 100644 index 0000000000..f90f5d19e2 --- /dev/null +++ b/src/sagemaker/inference_recommender/inference_recommender_mixin.py @@ -0,0 +1,294 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Placeholder docstring""" +from __future__ import absolute_import + +import logging + +from typing import List, Dict, Optional + +import sagemaker + +from sagemaker.parameter import CategoricalParameter + +INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = { + "xgboost": "XGBOOST", + "sklearn": "SAGEMAKER-SCIKIT-LEARN", + "pytorch": "PYTORCH", + "tensorflow": "TENSORFLOW", + "mxnet": "MXNET", +} + +LOGGER = logging.getLogger("sagemaker") + + +class Phase: + """Used to store phases of a traffic pattern to perform endpoint load testing. + + Required for an Advanced Inference Recommendations Job + """ + + def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int): + """Initialze a `Phase`""" + self.to_json = { + "DurationInSeconds": duration_in_seconds, + "InitialNumberOfUsers": initial_number_of_users, + "SpawnRate": spawn_rate, + } + + +class ModelLatencyThreshold: + """Used to store inference request/response latency to perform endpoint load testing. 
+ + Required for an Advanced Inference Recommendations Job + """ + + def __init__(self, percentile: str, value_in_milliseconds: int): + """Initialze a `ModelLatencyThreshold`""" + self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds} + + +class InferenceRecommenderMixin: + """A mixin class for SageMaker ``Inference Recommender`` that will be extended by ``Model``""" + + def right_size( + self, + sample_payload_url: str = None, + supported_content_types: List[str] = None, + supported_instance_types: List[str] = None, + job_name: str = None, + framework: str = None, + job_duration_in_seconds: int = None, + hyperparameter_ranges: List[Dict[str, CategoricalParameter]] = None, + phases: List[Phase] = None, + traffic_type: str = None, + max_invocations: int = None, + model_latency_thresholds: List[ModelLatencyThreshold] = None, + max_tests: int = None, + max_parallel_tests: int = None, + log_level: Optional[str] = "Verbose", + ): + """Recommends an instance type for a SageMaker or BYOC model. + + Args: + sample_payload_url (str): The S3 path where the sample payload is stored. + supported_content_types: (list[str]): The supported MIME types for the input data. + supported_instance_types (list[str]): A list of the instance types that this model + is expected to work on. (default: None). + job_name (str): The name of the Inference Recommendations Job. (default: None). + framework (str): The machine learning framework of the Image URI. + Only required to specify if you bring your own custom containers (default: None). + job_duration_in_seconds (int): The maximum job duration that a job can run for. + (default: None). + hyperparameter_ranges (list[Dict[str, sagemaker.parameter.CategoricalParameter]]): + Specifies the hyper parameters to be used during endpoint load tests. + `instance_type` must be specified as a hyperparameter range. + `env_vars` can be specified as an optional hyperparameter range. (default: None). 
+ Example:: + + hyperparameter_ranges = [{ + 'instance_types': CategoricalParameter(['ml.c5.xlarge', 'ml.c5.2xlarge']), + 'OMP_NUM_THREADS': CategoricalParameter(['1', '2', '3', '4']) + }] + + phases (list[Phase]): Specifies the criteria for increasing load + during endpoint load tests. (default: None). + traffic_type (str): Specifies the traffic type that matches the phases. (default: None). + max_invocations (str): defines invocation limit for endpoint load tests (default: None). + model_latency_thresholds (list[ModelLatencyThreshold]): defines the response latency + thresholds for endpoint load tests (default: None). + max_tests (int): restricts how many endpoints are allowed to be + spun up for this job (default: None). + max_parallel_tests (int): restricts how many concurrent endpoints + this job is allowed to spin up (default: None). + log_level (str): specifies the inline output when waiting for right_size to complete + (default: "Verbose"). + + Returns: + sagemaker.model.Model: A SageMaker ``Model`` object. See + :func:`~sagemaker.model.Model` for full details. 
+ """ + if not isinstance(self, sagemaker.model.ModelPackage): + raise ValueError("right_size() is currently only supported with a registered model") + + if not framework and self._framework(): + framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(self._framework, framework) + + framework_version = self._get_framework_version() + + endpoint_configurations = self._convert_to_endpoint_configurations_json( + hyperparameter_ranges=hyperparameter_ranges + ) + traffic_pattern = self._convert_to_traffic_pattern_json( + traffic_type=traffic_type, phases=phases + ) + stopping_conditions = self._convert_to_stopping_conditions_json( + max_invocations=max_invocations, model_latency_thresholds=model_latency_thresholds + ) + resource_limit = self._convert_to_resource_limit_json( + max_tests=max_tests, max_parallel_tests=max_parallel_tests + ) + + if endpoint_configurations or traffic_pattern or stopping_conditions or resource_limit: + LOGGER.info("Advance Job parameters were specified. Running Advanced job...") + job_type = "Advanced" + else: + LOGGER.info("Advance Job parameters were not specified. 
Running Default job...") + job_type = "Default" + + self._init_sagemaker_session_if_does_not_exist() + + ret_name = self.sagemaker_session.create_inference_recommendations_job( + role=self.role, + job_name=job_name, + job_type=job_type, + job_duration_in_seconds=job_duration_in_seconds, + model_package_version_arn=self.model_package_arn, + framework=framework, + framework_version=framework_version, + sample_payload_url=sample_payload_url, + supported_content_types=supported_content_types, + supported_instance_types=supported_instance_types, + endpoint_configurations=endpoint_configurations, + traffic_pattern=traffic_pattern, + stopping_conditions=stopping_conditions, + resource_limit=resource_limit, + ) + + self.inference_recommender_job_results = ( + self.sagemaker_session.wait_for_inference_recommendations_job( + ret_name, log_level=log_level + ) + ) + self.inference_recommendations = self.inference_recommender_job_results.get( + "InferenceRecommendations" + ) + + return self + + def _check_inference_recommender_args( + self, + instance_type=None, + initial_instance_count=None, + accelerator_type=None, + serverless_inference_config=None, + async_inference_config=None, + ): + """Validates that Inference Recommendation parameters can be used in `model.deploy()` + + Args: + instance_type (str): The initial number of instances to run + in the ``Endpoint`` created from this ``Model``. If not using + serverless inference or the model has not called ``right_size()``, + then it need to be a number larger or equals + to 1 (default: None) + initial_instance_count (int):The EC2 instance type to deploy this Model to. + For example, 'ml.p2.xlarge', or 'local' for local mode. If not using + serverless inference or the model has not called ``right_size()``, + then it is required to deploy a model. + (default: None) + accelerator_type (str): whether accelerator_type has been passed into `model.deploy()`. 
+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig)): + whether serverless_inference_config has been passed into `model.deploy()`. + async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig): + whether async_inference_config has been passed into `model.deploy()`. + + Returns: + (string, int) or None: Top instance_type and associated initial_instance_count + if self.inference_recommender_job_results has been generated. Otherwise, return None. + """ + if accelerator_type: + raise ValueError("accelerator_type is not compatible with right_size().") + if instance_type or initial_instance_count: + LOGGER.warning( + "instance_type or initial_instance_count specified." + "Overriding right_size() recommendations." + ) + return None + if async_inference_config: + LOGGER.warning( + "async_inference_config is specified. Overriding right_size() recommendations." + ) + return None + if serverless_inference_config: + LOGGER.warning( + "serverless_inference_config is specified. Overriding right_size() recommendations." 
+ ) + return None + + instance_type = self.inference_recommendations[0]["EndpointConfiguration"]["InstanceType"] + initial_instance_count = self.inference_recommendations[0]["EndpointConfiguration"][ + "InitialInstanceCount" + ] + return (instance_type, initial_instance_count) + + def _convert_to_endpoint_configurations_json( + self, hyperparameter_ranges: List[Dict[str, CategoricalParameter]] + ): + """Bundle right_size() parameters into an endpoint configuration for Advanced job""" + if not hyperparameter_ranges: + return None + + endpoint_configurations_to_json = [] + for parameter_range in hyperparameter_ranges: + if not parameter_range.get("instance_types"): + raise ValueError("instance_type must be defined as a hyperparameter_range") + parameter_range = parameter_range.copy() + instance_types = parameter_range.get("instance_types").values + parameter_range.pop("instance_types") + + for instance_type in instance_types: + parameter_ranges = [] + for name, param in parameter_range.items(): + as_json = param.as_json_range(name) + as_json["Value"] = as_json.pop("Values") + parameter_ranges.append(as_json) + endpoint_configurations_to_json.append( + { + "EnvironmentParameterRanges": { + "CategoricalParameterRanges": parameter_ranges + }, + "InstanceType": instance_type, + } + ) + + return endpoint_configurations_to_json + + def _convert_to_traffic_pattern_json(self, traffic_type: str, phases: List[Phase]): + """Bundle right_size() parameters into a traffic pattern for Advanced job""" + if not phases: + return None + return { + "Phases": [phase.to_json for phase in phases], + "TrafficType": traffic_type if traffic_type else "PHASES", + } + + def _convert_to_resource_limit_json(self, max_tests: int, max_parallel_tests: int): + """Bundle right_size() parameters into a resource limit for Advanced job""" + if not max_tests and not max_parallel_tests: + return None + return { + "MaxNumberOfTests": max_tests, + "MaxParallelOfTests": max_parallel_tests, + } + + def 
_convert_to_stopping_conditions_json( + self, max_invocations: int, model_latency_thresholds: List[ModelLatencyThreshold] + ): + """Bundle right_size() parameters into stopping conditions for Advanced job""" + if not max_invocations and not model_latency_thresholds: + return None + return { + "MaxInvocations": max_invocations, + "ModelLatencyThresholds": [threshold.to_json for threshold in model_latency_thresholds], + } diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index e04b83a14f..a8d6347d8c 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -48,6 +48,7 @@ from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession +from sagemaker.inference_recommender.inference_recommender_mixin import InferenceRecommenderMixin LOGGER = logging.getLogger("sagemaker") @@ -83,7 +84,7 @@ def delete_model(self, *args, **kwargs) -> None: SAGEMAKER_OUTPUT_LOCATION = "sagemaker_s3_output" -class Model(ModelBase): +class Model(ModelBase, InferenceRecommenderMixin): """A SageMaker ``Model`` that can be deployed to an ``Endpoint``.""" def __init__( @@ -279,6 +280,8 @@ def __init__( self._is_compiled_model = False self._compilation_job_name = None self._is_edge_packaged_model = False + self.inference_recommender_job_results = None + self.inference_recommendations = None self._enable_network_isolation = enable_network_isolation self.model_kms_key = model_kms_key self.image_config = image_config @@ -1050,11 +1053,13 @@ def deploy( Args: initial_instance_count (int): The initial number of instances to run in the ``Endpoint`` created from this ``Model``. 
If not using - serverless inference, then it need to be a number larger or equals + serverless inference or the model has not called ``right_size()``, + then it need to be a number larger or equals to 1 (default: None) instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge', or 'local' for local mode. If not using - serverless inference, then it is required to deploy a model. + serverless inference or the model has not called ``right_size()``, + then it is required to deploy a model. (default: None) serializer (:class:`~sagemaker.serializers.BaseSerializer`): A serializer object, used to encode data for an inference endpoint @@ -1118,6 +1123,18 @@ def deploy( is not None. Otherwise, return None. """ removed_kwargs("update_endpoint", kwargs) + + if self.inference_recommender_job_results: + inference_recommendation = self._check_inference_recommender_args( + instance_type, + initial_instance_count, + accelerator_type, + serverless_inference_config, + async_inference_config, + ) + if inference_recommendation: + instance_type, initial_instance_count = inference_recommendation + self._init_sagemaker_session_if_does_not_exist(instance_type) tags = add_jumpstart_tags( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 5404978200..0df2996352 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -21,6 +21,7 @@ import time import typing import warnings +import uuid from typing import List, Dict, Any, Sequence, Optional import boto3 @@ -53,10 +54,12 @@ _STATUS_CODE_TABLE = { "COMPLETED": "Completed", "INPROGRESS": "InProgress", + "IN_PROGRESS": "InProgress", "FAILED": "Failed", "STOPPED": "Stopped", "STOPPING": "Stopping", "STARTING": "Starting", + "PENDING": "Pending", } @@ -4655,6 +4658,231 @@ def _intercept_create_request( """ return create(request) + def _create_inference_recommendations_job_request( + self, + role: str, + job_name: str, + job_description: str, + framework: str, + 
sample_payload_url: str, + supported_content_types: List[str], + model_package_version_arn: str = None, + job_duration_in_seconds: int = None, + job_type: str = "Default", + framework_version: str = None, + nearest_model_name: str = None, + supported_instance_types: List[str] = None, + endpoint_configurations: List[Dict[str, Any]] = None, + traffic_pattern: Dict[str, Any] = None, + stopping_conditions: Dict[str, Any] = None, + resource_limit: Dict[str, Any] = None, + ) -> Dict[str, Any]: + """Get request dictionary for CreateInferenceRecommendationsJob API. + + Args: + role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training + jobs and APIs that create Amazon SageMaker endpoints use this role to access + training data and model artifacts. + You must grant sufficient permissions to this role. + job_name (str): The name of the Inference Recommendations Job. + job_description (str): A description of the Inference Recommendations Job. + framework (str): The machine learning framework of the Image URI. + sample_payload_url (str): The S3 path where the sample payload is stored. + supported_content_types (List[str]): The supported MIME types for the input data. + model_package_version_arn (str): The Amazon Resource Name (ARN) of a + versioned model package. + job_duration_in_seconds (int): The maximum job duration that a job + can run for. Will be used for `Advanced` jobs. + job_type (str): The type of job being run. Must either be `Default` or `Advanced`. + framework_version (str): The framework version of the Image URI. + nearest_model_name (str): The name of a pre-trained machine learning model + benchmarked by Amazon SageMaker Inference Recommender that matches your model. + supported_instance_types (List[str]): A list of the instance types that are used + to generate inferences in real-time. + endpoint_configurations (List[Dict[str, any]]): Specifies the endpoint configurations + to use for a job. Will be used for `Advanced` jobs. 
+ traffic_pattern (Dict[str, any]): Specifies the traffic pattern for the job. + Will be used for `Advanced` jobs. + stopping_conditions (Dict[str, any]): A set of conditions for stopping a + recommendation job. + If any of the conditions are met, the job is automatically stopped. + Will be used for `Advanced` jobs. + resource_limit (Dict[str, any]): Defines the resource limit for the job. + Will be used for `Advanced` jobs. + Returns: + Dict[str, Any]: request dictionary for the CreateInferenceRecommendationsJob API + """ + + containerConfig = { + "Domain": "MACHINE_LEARNING", + "Task": "OTHER", + "Framework": framework, + "PayloadConfig": { + "SamplePayloadUrl": sample_payload_url, + "SupportedContentTypes": supported_content_types, + }, + } + + if framework_version: + containerConfig["FrameworkVersion"] = framework_version + if nearest_model_name: + containerConfig["NearestModelName"] = nearest_model_name + if supported_instance_types: + containerConfig["SupportedInstanceTypes"] = supported_instance_types + + request = { + "JobName": job_name, + "JobType": job_type, + "RoleArn": role, + "InputConfig": { + "ContainerConfig": containerConfig, + "ModelPackageVersionArn": model_package_version_arn, + }, + } + + if job_description: + request["JobDescription"] = job_description + if job_duration_in_seconds: + request["InputConfig"]["JobDurationInSeconds"] = job_duration_in_seconds + + if job_type == "Advanced": + if stopping_conditions: + request["StoppingConditions"] = stopping_conditions + if resource_limit: + request["InputConfig"]["ResourceLimit"] = resource_limit + if traffic_pattern: + request["InputConfig"]["TrafficPattern"] = traffic_pattern + if endpoint_configurations: + request["InputConfig"]["EndpointConfigurations"] = endpoint_configurations + + return request + + def create_inference_recommendations_job( + self, + role: str, + sample_payload_url: str, + supported_content_types: List[str], + job_name: str = None, + job_type: str = "Default", + 
model_package_version_arn: str = None, + job_duration_in_seconds: int = None, + nearest_model_name: str = None, + supported_instance_types: List[str] = None, + framework: str = None, + framework_version: str = None, + endpoint_configurations: List[Dict[str, any]] = None, + traffic_pattern: Dict[str, any] = None, + stopping_conditions: Dict[str, any] = None, + resource_limit: Dict[str, any] = None, + ): + """Creates an Inference Recommendations Job + + Args: + role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training + jobs and APIs that create Amazon SageMaker endpoints use this role to access + training data and model artifacts. + You must grant sufficient permissions to this role. + sample_payload_url (str): The S3 path where the sample payload is stored. + supported_content_types (List[str]): The supported MIME types for the input data. + model_package_version_arn (str): The Amazon Resource Name (ARN) of a + versioned model package. + job_name (str): The name of the job being run. + job_type (str): The type of job being run. Must either be `Default` or `Advanced`. + job_duration_in_seconds (int): The maximum job duration that a job + can run for. Will be used for `Advanced` jobs. + nearest_model_name (str): The name of a pre-trained machine learning model + benchmarked by Amazon SageMaker Inference Recommender that matches your model. + supported_instance_types (List[str]): A list of the instance types that are used + to generate inferences in real-time. + framework (str): The machine learning framework of the Image URI. + framework_version (str): The framework version of the Image URI. + endpoint_configurations (List[Dict[str, any]]): Specifies the endpoint configurations + to use for a job. Will be used for `Advanced` jobs. + traffic_pattern (Dict[str, any]): Specifies the traffic pattern for the job. + Will be used for `Advanced` jobs. + stopping_conditions (Dict[str, any]): A set of conditions for stopping a + recommendation job. 
+ If any of the conditions are met, the job is automatically stopped. + Will be used for `Advanced` jobs. + resource_limit (Dict[str, any]): Defines the resource limit for the job. + Will be used for `Advanced` jobs. + Returns: + str: The name of the job created. In the form of `SMPYTHONSDK-` + """ + + if not job_name: + unique_tail = uuid.uuid4() + job_name = "SMPYTHONSDK-" + str(unique_tail) + job_description = "#python-sdk-create" + + create_inference_recommendations_job_request = ( + self._create_inference_recommendations_job_request( + role=role, + model_package_version_arn=model_package_version_arn, + job_name=job_name, + job_type=job_type, + job_duration_in_seconds=job_duration_in_seconds, + job_description=job_description, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + sample_payload_url=sample_payload_url, + supported_content_types=supported_content_types, + supported_instance_types=supported_instance_types, + endpoint_configurations=endpoint_configurations, + traffic_pattern=traffic_pattern, + stopping_conditions=stopping_conditions, + resource_limit=resource_limit, + ) + ) + + def submit(request): + LOGGER.info("Creating Inference Recommendations job with name: %s", job_name) + LOGGER.debug("process request: %s", json.dumps(request, indent=4)) + self.sagemaker_client.create_inference_recommendations_job(**request) + + self._intercept_create_request( + create_inference_recommendations_job_request, + submit, + self.create_inference_recommendations_job.__name__, + ) + return job_name + + def wait_for_inference_recommendations_job( + self, job_name: str, poll: int = 120, log_level: str = "Verbose" + ) -> Dict[str, Any]: + """Wait for an Amazon SageMaker Inference Recommender job to complete. + + Args: + job_name (str): Name of the Inference Recommender job to wait for. + poll (int): Polling interval in seconds (default: 120). + log_level (str): The level of verbosity for the logs. 
+ Can be "Quiet" or "Verbose" (default: "Quiet"). + + Returns: + (dict): Return value from the ``DescribeInferenceRecommendationsJob`` API. + + Raises: + exceptions.CapacityError: If the Inference Recommender job fails with CapacityError. + exceptions.UnexpectedStatusException: If the Inference Recommender job fails. + """ + if log_level == "Quiet": + _wait_until( + lambda: _describe_inference_recommendations_job_status( + self.sagemaker_client, job_name + ), + poll, + ) + elif log_level == "Verbose": + _display_inference_recommendations_job_steps_status( + self, self.sagemaker_client, job_name + ) + else: + raise ValueError("log_level must be either Quiet or Verbose") + desc = _describe_inference_recommendations_job_status(self.sagemaker_client, job_name) + self._check_job_status(job_name, desc, "Status") + return desc + def get_model_package_args( content_types, @@ -5276,6 +5504,118 @@ def _create_model_package_status(sagemaker_client, model_package_name): return desc +def _describe_inference_recommendations_job_status(sagemaker_client, job_name: str): + """Describes the status of a job and returns the job description. + + Args: + sagemaker_client (boto3.client.sagemaker): A SageMaker client. + job_name (str): The name of the job. + + Returns: + dict: The job description, or None if the job is still in progress. 
+ """ + inference_recommendations_job_status_codes = { + "PENDING": ".", + "IN_PROGRESS": ".", + "COMPLETED": "!", + "FAILED": "*", + "STOPPING": "_", + "STOPPED": "s", + } + in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"} + + desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name) + status = desc["Status"] + + print(inference_recommendations_job_status_codes.get(status, "?"), end="", flush=True) + + if status in in_progress_statuses: + return None + + print("") + return desc + + +def _display_inference_recommendations_job_steps_status( + sagemaker_session, sagemaker_client, job_name: str, poll: int = 60 +): + """Placeholder docstring""" + cloudwatch_client = sagemaker_session.boto_session.client("logs") + in_progress_statuses = {"PENDING", "IN_PROGRESS", "STOPPING"} + log_group_name = "/aws/sagemaker/InferenceRecommendationsJobs" + log_stream_name = job_name + "/execution" + + initial_logs_batch = get_log_events_for_inference_recommender( + cloudwatch_client, log_group_name, log_stream_name + ) + print(f"Retrieved logStream: {log_stream_name} from logGroup: {log_group_name}", flush=True) + events = initial_logs_batch["events"] + print(*[event["message"] for event in events], sep="\n", flush=True) + + next_forward_token = initial_logs_batch["nextForwardToken"] if events else None + flush_remaining = True + while True: + logs_batch = ( + cloudwatch_client.get_log_events( + logGroupName=log_group_name, + logStreamName=log_stream_name, + nextToken=next_forward_token, + ) + if next_forward_token + else cloudwatch_client.get_log_events( + logGroupName=log_group_name, logStreamName=log_stream_name + ) + ) + + events = logs_batch["events"] + + desc = sagemaker_client.describe_inference_recommendations_job(JobName=job_name) + status = desc["Status"] + + if not events: + if status in in_progress_statuses: + time.sleep(poll) + continue + if flush_remaining: + flush_remaining = False + time.sleep(poll) + continue + + next_forward_token = 
logs_batch["nextForwardToken"] + print(*[event["message"] for event in events], sep="\n", flush=True) + + if status not in in_progress_statuses: + break + + time.sleep(poll) + + +def get_log_events_for_inference_recommender(cw_client, log_group_name, log_stream_name): + """Retrieves log events from the specified CloudWatch log group and log stream. + + Args: + cw_client (boto3.client): A boto3 CloudWatch client. + log_group_name (str): The name of the CloudWatch log group. + log_stream_name (str): The name of the CloudWatch log stream. + + Returns: + (dict): A dictionary containing log events from CloudWatch log group and log stream. + """ + print("Fetching logs from CloudWatch...", flush=True) + for _ in retries( + max_retry_count=30, # 30*10 = 5min + exception_message_prefix="Waiting for cloudwatch stream to appear. ", + seconds_to_sleep=10, + ): + try: + return cw_client.get_log_events( + logGroupName=log_group_name, logStreamName=log_stream_name + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + pass + + def _deploy_done(sagemaker_client, endpoint_name): """Placeholder docstring""" hosting_status_codes = { diff --git a/tests/data/inference_recommender/inference.py b/tests/data/inference_recommender/inference.py new file mode 100644 index 0000000000..05557e84d2 --- /dev/null +++ b/tests/data/inference_recommender/inference.py @@ -0,0 +1,47 @@ +from __future__ import absolute_import + +import argparse +import joblib +import os + +import pandas as pd +from sklearn.linear_model import LogisticRegression + + +# inference functions --------------- +def model_fn(model_dir): + clf = joblib.load(os.path.join(model_dir, "model.joblib")) + return clf + + +if __name__ == "__main__": + + print("extracting arguments") + parser = argparse.ArgumentParser() + + # Data, model, and output directories + parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR")) + parser.add_argument("--train", type=str, 
default=os.environ.get("SM_CHANNEL_TRAIN")) + parser.add_argument("--test", type=str, default=os.environ.get("SM_CHANNEL_TEST")) + parser.add_argument("--train-file", type=str, default="candybar_train.csv") + parser.add_argument("--test-file", type=str, default="candybar_test.csv") + + args, _ = parser.parse_known_args() + + print("reading data") + X_train = pd.read_csv(os.path.join(args.train, args.train_file)) + y_train = pd.read_csv(os.path.join(args.test, args.test_file)) + + # train + print("training model") + model = LogisticRegression() + + X_train = X_train.iloc[:, 1:] + y_train = y_train.iloc[:, 1:] + + model.fit(X_train, y_train) + + # persist model + path = os.path.join(args.model_dir, "model.joblib") + joblib.dump(model, path) + print("model persisted at " + path) diff --git a/tests/data/inference_recommender/sample.csv b/tests/data/inference_recommender/sample.csv new file mode 100644 index 0000000000..e7f0e133e4 --- /dev/null +++ b/tests/data/inference_recommender/sample.csv @@ -0,0 +1,26 @@ +1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,8.331627499112128632e-01,6.632124490858813948e-01,7.148178578717218068e-01 +1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,3.091095150278304060e-01,7.834197017262206630e-01,7.131347755709654956e-01 +1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,6.632124490858813948e-01,8.170827192227857472e-01 
+1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,3.091095150278304060e-01,9.398963425917477021e-01,5.196588078496671148e-01 +0.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,6.069600636020484608e-01,3.253885873016726937e-01,1.955511455917782193e-01 +0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,4.165813903087285386e-01,3.253885873016726937e-01,1.585107651097273918e-01 +0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,9.160696400423672392e-01,3.253885873016726937e-01,2.521363061302601682e-01 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.545547626316225720e-01,1.088082890278933845e-01,7.869973295029364380e-02 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,8.567042038219898625e-01,3.253885873016726937e-01,5.330644235245095564e-01 +1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,8.812691889717336746e-01,8.673574936025130189e-01,4.386295354354305953e-01 
+1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,4.646878293692493500e-01,7.834197017262206630e-01,4.519677338682178691e-01 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,5.936540346705024285e-02,1.088082890278933845e-01,6.061178895604173444e-01 +1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,3.091095150278304060e-01,5.181346889312467008e-01,3.340661381286635923e-01 +0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,4.646878293692493500e-01,7.834197017262206630e-01,3.863493270238224087e-01 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,4.646878293692493500e-01,4.704663163682247240e-01,4.692170232456173151e-01 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,5.711361529403511383e-01,4.870466063518485988e-02,1.965442621488194819e-01 +1.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,2.978505588096327372e-01,5.181346889312467008e-01,5.283673434313520545e-01 
+1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,6.069600636020484608e-01,7.834197017262206630e-01,7.822198087504695918e-01 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,3.582395160335622580e-02,3.253885873016726937e-01,3.368670475454677016e-02 +1.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,8.689867219853986136e-01,8.797927473596606207e-01,7.009148902026304251e-01 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,1.432957961780101652e-01,2.165803086364734842e-01,7.223183581151090271e-01 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,8.393040294637467424e-02,1.088082890278933845e-01,2.683497316892575757e-01 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,5.834186096912714614e-01,1.088082890278933845e-01,2.414059012181253294e-01 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,9.160696400423672392e-01,4.580310936991596193e-01,3.942338560934098846e-01 
+0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,2.139201681457557624e-01,1.088082890278933845e-01,3.068637183129446777e-01 +0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,0.000000000000000000e+00,1.000000000000000000e+00,0.000000000000000000e+00,7.379733927937235372e-01,9.886009949367772220e-01,2.080736310319135640e-01 diff --git a/tests/data/inference_recommender/sklearn-model.tar.gz b/tests/data/inference_recommender/sklearn-model.tar.gz new file mode 100644 index 0000000000..305bb31429 Binary files /dev/null and b/tests/data/inference_recommender/sklearn-model.tar.gz differ diff --git a/tests/data/inference_recommender/sklearn-payload.tar.gz b/tests/data/inference_recommender/sklearn-payload.tar.gz new file mode 100644 index 0000000000..a6dd9c7f2b Binary files /dev/null and b/tests/data/inference_recommender/sklearn-payload.tar.gz differ diff --git a/tests/integ/test_inference_recommender.py b/tests/integ/test_inference_recommender.py new file mode 100644 index 0000000000..ab06ec6851 --- /dev/null +++ b/tests/integ/test_inference_recommender.py @@ -0,0 +1,211 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import os + +import pytest + +from sagemaker.sklearn.model import SKLearnModel, SKLearnPredictor +from sagemaker.utils import unique_name_from_base +from tests.integ import DATA_DIR +from tests.integ.timeout import timeout +import pandas as pd +from sagemaker.inference_recommender.inference_recommender_mixin import Phase, ModelLatencyThreshold +from sagemaker.parameter import CategoricalParameter +import logging + +logger = logging.getLogger(__name__) + +# Running integration tests on SKLearn model +IR_DIR = os.path.join(DATA_DIR, "inference_recommender") +IR_SKLEARN_MODEL = os.path.join(IR_DIR, "sklearn-model.tar.gz") +IR_SKLEARN_ENTRY_POINT = os.path.join(IR_DIR, "inference.py") +IR_SKLEARN_PAYLOAD = os.path.join(IR_DIR, "sklearn-payload.tar.gz") +IR_SKLEARN_DATA = os.path.join(IR_DIR, "sample.csv") +IR_SKLEARN_CONTENT_TYPE = ["text/csv"] +IR_SKLEARN_FRAMEWORK = "SAGEMAKER-SCIKIT-LEARN" +IR_SKLEARN_FRAMEWORK_VERSION = "1.0-1" + + +@pytest.fixture(scope="module") +def default_right_sized_model(sagemaker_session, cpu_instance_type): + with timeout(minutes=45): + try: + model_package_group_name = unique_name_from_base("test-ir-right-size-model-pkg-sklearn") + model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL) + payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD) + + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + sklearn_model = SKLearnModel( + model_data=model_data, + role=role_arn, + entry_point=IR_SKLEARN_ENTRY_POINT, + framework_version=IR_SKLEARN_FRAMEWORK_VERSION, + ) + + sklearn_model_package = sklearn_model.register( + content_types=IR_SKLEARN_CONTENT_TYPE, + response_types=IR_SKLEARN_CONTENT_TYPE, + model_package_group_name=model_package_group_name, + image_uri=sklearn_model.image_uri, + approval_status="Approved", + ) + + return ( + sklearn_model_package.right_size( + 
sample_payload_url=payload_data, + supported_content_types=IR_SKLEARN_CONTENT_TYPE, + supported_instance_types=[cpu_instance_type], + framework=IR_SKLEARN_FRAMEWORK, + log_level="Quiet", + ), + model_package_group_name, + ) + except Exception: + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=sklearn_model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) + + +@pytest.fixture(scope="module") +def advanced_right_sized_model(sagemaker_session, cpu_instance_type): + with timeout(minutes=45): + try: + model_package_group_name = unique_name_from_base("test-ir-right-size-model-pkg-sklearn") + model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL) + payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD) + + iam_client = sagemaker_session.boto_session.client("iam") + role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"] + + sklearn_model = SKLearnModel( + model_data=model_data, + role=role_arn, + entry_point=IR_SKLEARN_ENTRY_POINT, + framework_version=IR_SKLEARN_FRAMEWORK_VERSION, + ) + + sklearn_model_package = sklearn_model.register( + content_types=IR_SKLEARN_CONTENT_TYPE, + response_types=IR_SKLEARN_CONTENT_TYPE, + model_package_group_name=model_package_group_name, + image_uri=sklearn_model.image_uri, + approval_status="Approved", + ) + + hyperparameter_ranges = [ + { + "instance_types": CategoricalParameter([cpu_instance_type]), + "TEST_PARAM": CategoricalParameter( + ["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"] + ), + } + ] + + phases = [ + Phase(duration_in_seconds=300, initial_number_of_users=2, spawn_rate=2), + Phase(duration_in_seconds=300, initial_number_of_users=14, spawn_rate=2), + ] + + model_latency_thresholds = [ + ModelLatencyThreshold(percentile="P95", value_in_milliseconds=100) + ] + + return ( + sklearn_model_package.right_size( + sample_payload_url=payload_data, + 
supported_content_types=IR_SKLEARN_CONTENT_TYPE, + framework=IR_SKLEARN_FRAMEWORK, + job_duration_in_seconds=3600, + hyperparameter_ranges=hyperparameter_ranges, + phases=phases, + model_latency_thresholds=model_latency_thresholds, + max_invocations=100, + max_tests=5, + max_parallel_tests=5, + ), + model_package_group_name, + ) + except Exception: + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=sklearn_model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) + + +@pytest.mark.slow_test +def test_default_right_size_and_deploy_registered_model_sklearn( + default_right_sized_model, sagemaker_session +): + endpoint_name = unique_name_from_base("test-ir-right-size-default-sklearn") + + right_size_model_package, model_package_group_name = default_right_sized_model + with timeout(minutes=45): + try: + right_size_model_package.predictor_cls = SKLearnPredictor + predictor = right_size_model_package.deploy(endpoint_name=endpoint_name) + + payload = pd.read_csv(IR_SKLEARN_DATA, header=None) + + inference = predictor.predict(payload) + assert inference is not None + assert 26 == len(inference) + finally: + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=right_size_model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) + predictor.delete_model() + predictor.delete_endpoint() + + +@pytest.mark.slow_test +def test_advanced_right_size_and_deploy_registered_model_sklearn( + advanced_right_sized_model, sagemaker_session +): + endpoint_name = unique_name_from_base("test-ir-right-size-advanced-sklearn") + + right_size_model_package, model_package_group_name = advanced_right_sized_model + with timeout(minutes=45): + try: + right_size_model_package.predictor_cls = SKLearnPredictor + predictor = 
right_size_model_package.deploy(endpoint_name=endpoint_name) + + payload = pd.read_csv(IR_SKLEARN_DATA, header=None) + + inference = predictor.predict(payload) + assert inference is not None + assert 26 == len(inference) + finally: + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=right_size_model_package.model_package_arn + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) + predictor.delete_model() + predictor.delete_endpoint() + + +# TODO when we've added support for inference_recommendation_id +# then add tests to test Framework models diff --git a/tests/unit/sagemaker/inference_recommender/__init__.py b/tests/unit/sagemaker/inference_recommender/__init__.py new file mode 100644 index 0000000000..a6987bc6a6 --- /dev/null +++ b/tests/unit/sagemaker/inference_recommender/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import diff --git a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py new file mode 100644 index 0000000000..a8aa219dd0 --- /dev/null +++ b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py @@ -0,0 +1,491 @@ +from __future__ import absolute_import + +from unittest.mock import patch, MagicMock, ANY + +from sagemaker.model import Model, ModelPackage +from sagemaker.parameter import CategoricalParameter +from sagemaker.inference_recommender.inference_recommender_mixin import ( + Phase, + ModelLatencyThreshold, +) +from sagemaker.async_inference import AsyncInferenceConfig +from sagemaker.serverless import ServerlessInferenceConfig + +import pytest + +REGION = "us-west-2" + +MODEL_NAME = "model-name-for-ir" +MODEL_DATA = "s3://bucket/model.tar.gz" +MODEL_IMAGE = "model-image-for-ir" +MODEL_PACKAGE_ARN = "model-package-for-ir" + +IR_ROLE_ARN = "arn:aws:iam::123456789123:role/service-role/AmazonSageMaker-ExecutionRole-UnitTest" +IR_SAMPLE_PAYLOAD_URL = "s3://sagemaker-us-west-2-123456789123/payload/payload.tar.gz" +IR_SAMPLE_FRAMEWORK = "SAGEMAKER-SCIKIT-LEARN" +IR_SUPPORTED_CONTENT_TYPES = ["text/csv"] +IR_JOB_NAME = "SMPYTHONSDK-1234567891" +IR_SAMPLE_INSTANCE_TYPE = "ml.c5.xlarge" + +IR_SAMPLE_LIST_OF_INSTANCES_HYPERPARAMETER_RANGES = [ + { + "instance_types": CategoricalParameter(["ml.m5.xlarge", "ml.g4dn.xlarge"]), + "TEST_PARAM": CategoricalParameter(["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"]), + } +] + +IR_SAMPLE_SINGLE_INSTANCES_HYPERPARAMETER_RANGES = [ + { + "instance_types": CategoricalParameter(["ml.m5.xlarge"]), + "TEST_PARAM": CategoricalParameter(["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"]), + }, + { + "instance_types": CategoricalParameter(["ml.g4dn.xlarge"]), + "TEST_PARAM": CategoricalParameter(["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"]), + }, +] + 
+IR_SAMPLE_INVALID_HYPERPARAMETERS_RANGES = [ + { + "TEST_PARAM": CategoricalParameter(["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"]), + "TEST_PARAM2": CategoricalParameter(["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"]), + } +] + +IR_SAMPLE_PHASES = [ + Phase(duration_in_seconds=300, initial_number_of_users=2, spawn_rate=2), + Phase(duration_in_seconds=300, initial_number_of_users=14, spawn_rate=2), +] + +IR_SAMPLE_MODEL_LATENCY_THRESHOLDS = [ + ModelLatencyThreshold(percentile="P95", value_in_milliseconds=100) +] + +IR_RIGHT_SIZE_INSTANCE_TYPE = "ml.m5.xlarge" +IR_RIGHT_SIZE_INITIAL_INSTANCE_COUNT = 1 + +IR_SAMPLE_INFERENCE_RESPONSE = { + "JobName": "SMPYTHONSDK-1671044837", + "JobDescription": "#python-sdk-create", + "PlaceHolder": "...", + "InferenceRecommendations": [ + { + "Metrics": {"PlaceHolder": "..."}, + "EndpointConfiguration": { + "EndpointName": "sm-epc-test", + "VariantName": "sm-epc-test", + "InstanceType": IR_RIGHT_SIZE_INSTANCE_TYPE, + "InitialInstanceCount": IR_RIGHT_SIZE_INITIAL_INSTANCE_COUNT, + }, + "ModelConfiguration": {"PlaceHolder": "..."}, + } + ], + "PlaceHolder": "...", +} + +IR_DEPLOY_ENDPOINT_NAME = "ir-endpoint-test" + +IR_SAMPLE_ENDPOINT_CONFIG = [ + { + "EnvironmentParameterRanges": { + "CategoricalParameterRanges": [ + { + "Name": "TEST_PARAM", + "Value": ["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"], + }, + ], + }, + "InstanceType": "ml.m5.xlarge", + }, + { + "EnvironmentParameterRanges": { + "CategoricalParameterRanges": [ + { + "Name": "TEST_PARAM", + "Value": ["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"], + }, + ], + }, + "InstanceType": "ml.g4dn.xlarge", + }, +] + +IR_SAMPLE_TRAFFIC_PATTERN = { + "Phases": [ + { + "DurationInSeconds": 300, + "InitialNumberOfUsers": 2, + "SpawnRate": 2, + }, + { + "DurationInSeconds": 300, + "InitialNumberOfUsers": 14, + "SpawnRate": 2, + }, + ], + "TrafficType": "PHASES", +} + +IR_SAMPLE_STOPPING_CONDITIONS = { + "MaxInvocations": 100, + "ModelLatencyThresholds": [ + { + "Percentile": "P95", + 
"ValueInMilliseconds": 100, + } + ], +} + +IR_SAMPLE_RESOURCE_LIMIT = { + "MaxNumberOfTests": 5, + "MaxParallelOfTests": 5, +} + + +@pytest.fixture() +def sagemaker_session(): + session = MagicMock(boto_region_name=REGION) + + session.create_inference_recommendations_job.return_value = IR_JOB_NAME + session.wait_for_inference_recommendations_job.return_value = IR_SAMPLE_INFERENCE_RESPONSE + + return session + + +@pytest.fixture() +def model_package(sagemaker_session): + return ModelPackage( + role=IR_ROLE_ARN, model_package_arn=MODEL_PACKAGE_ARN, sagemaker_session=sagemaker_session + ) + + +@pytest.fixture() +def model(sagemaker_session): + return Model(MODEL_IMAGE, MODEL_DATA, role=IR_ROLE_ARN, sagemaker_session=sagemaker_session) + + +@pytest.fixture() +def default_right_sized_model(model_package): + return model_package.right_size( + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], + job_name=IR_JOB_NAME, + framework=IR_SAMPLE_FRAMEWORK, + ) + + +def test_right_size_default_with_model_package_successful(sagemaker_session, model_package): + inference_recommender_model_pkg = model_package.right_size( + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], + job_name=IR_JOB_NAME, + framework=IR_SAMPLE_FRAMEWORK, + ) + + # assert that the create api has been called with default parameters + assert sagemaker_session.create_inference_recommendations_job.called_with( + role=IR_ROLE_ARN, + job_name=IR_JOB_NAME, + job_type="Default", + job_duration_in_seconds=None, + model_package_version_arn=model_package.model_package_arn, + framework=IR_SAMPLE_FRAMEWORK, + framework_version=None, + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], + endpoint_configurations=None, + 
traffic_pattern=None, + stopping_conditions=None, + resource_limit=None, + ) + + assert sagemaker_session.wait_for_inference_recomendations_job.called_with(IR_JOB_NAME) + + # confirm that the IR instance attributes have been set + assert ( + inference_recommender_model_pkg.inference_recommender_job_results + == IR_SAMPLE_INFERENCE_RESPONSE + ) + assert ( + inference_recommender_model_pkg.inference_recommendations + == IR_SAMPLE_INFERENCE_RESPONSE["InferenceRecommendations"] + ) + + # confirm that the returned object of right_size is itself + assert inference_recommender_model_pkg == model_package + + +def test_right_size_advanced_list_instances_model_package_successful( + sagemaker_session, model_package +): + inference_recommender_model_pkg = model_package.right_size( + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + framework="SAGEMAKER-SCIKIT-LEARN", + job_duration_in_seconds=7200, + hyperparameter_ranges=IR_SAMPLE_LIST_OF_INSTANCES_HYPERPARAMETER_RANGES, + phases=IR_SAMPLE_PHASES, + traffic_type="PHASES", + max_invocations=100, + model_latency_thresholds=IR_SAMPLE_MODEL_LATENCY_THRESHOLDS, + max_tests=5, + max_parallel_tests=5, + ) + + # assert that the create api has been called with advanced parameters + assert sagemaker_session.create_inference_recommendations_job.called_with( + role=IR_ROLE_ARN, + job_name=IR_JOB_NAME, + job_type="Advanced", + job_duration_in_seconds=7200, + model_package_version_arn=model_package.model_package_arn, + framework=IR_SAMPLE_FRAMEWORK, + framework_version=None, + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], + endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG, + traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN, + stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS, + resource_limit=IR_SAMPLE_RESOURCE_LIMIT, + ) + + assert 
sagemaker_session.wait_for_inference_recomendations_job.called_with(IR_JOB_NAME) + + # confirm that the IR instance attributes have been set + assert ( + inference_recommender_model_pkg.inference_recommender_job_results + == IR_SAMPLE_INFERENCE_RESPONSE + ) + assert ( + inference_recommender_model_pkg.inference_recommendations + == IR_SAMPLE_INFERENCE_RESPONSE["InferenceRecommendations"] + ) + + # confirm that the returned object of right_size is itself + assert inference_recommender_model_pkg == model_package + + +def test_right_size_advanced_single_instances_model_package_successful( + sagemaker_session, model_package +): + model_package.right_size( + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + framework="SAGEMAKER-SCIKIT-LEARN", + job_duration_in_seconds=7200, + hyperparameter_ranges=IR_SAMPLE_SINGLE_INSTANCES_HYPERPARAMETER_RANGES, + phases=IR_SAMPLE_PHASES, + traffic_type="PHASES", + max_invocations=100, + model_latency_thresholds=IR_SAMPLE_MODEL_LATENCY_THRESHOLDS, + max_tests=5, + max_parallel_tests=5, + ) + + # assert that the create api has been called with advanced parameters + assert sagemaker_session.create_inference_recommendations_job.called_with( + role=IR_ROLE_ARN, + job_name=IR_JOB_NAME, + job_type="Advanced", + job_duration_in_seconds=7200, + model_package_version_arn=model_package.model_package_arn, + framework=IR_SAMPLE_FRAMEWORK, + framework_version=None, + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], + endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG, + traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN, + stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS, + resource_limit=IR_SAMPLE_RESOURCE_LIMIT, + ) + + +def test_right_size_advanced_model_package_partial_params_successful( + sagemaker_session, model_package +): + model_package.right_size( + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + 
supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + framework="SAGEMAKER-SCIKIT-LEARN", + job_duration_in_seconds=7200, + hyperparameter_ranges=IR_SAMPLE_SINGLE_INSTANCES_HYPERPARAMETER_RANGES, + phases=IR_SAMPLE_PHASES, + traffic_type="PHASES", + max_invocations=100, + model_latency_thresholds=IR_SAMPLE_MODEL_LATENCY_THRESHOLDS, + ) + + # assert that the create api has been called with advanced parameters + assert sagemaker_session.create_inference_recommendations_job.called_with( + role=IR_ROLE_ARN, + job_name=IR_JOB_NAME, + job_type="Advanced", + job_duration_in_seconds=7200, + model_package_version_arn=model_package.model_package_arn, + framework=IR_SAMPLE_FRAMEWORK, + framework_version=None, + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], + endpoint_configurations=IR_SAMPLE_ENDPOINT_CONFIG, + traffic_pattern=IR_SAMPLE_TRAFFIC_PATTERN, + stopping_conditions=IR_SAMPLE_STOPPING_CONDITIONS, + resource_limit=None, + ) + + +def test_right_size_invalid_hyperparameter_ranges(sagemaker_session, model_package): + with pytest.raises( + ValueError, + match="instance_type must be defined as a hyperparameter_range", + ): + model_package.right_size( + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + framework="SAGEMAKER-SCIKIT-LEARN", + job_duration_in_seconds=7200, + hyperparameter_ranges=IR_SAMPLE_INVALID_HYPERPARAMETERS_RANGES, + phases=IR_SAMPLE_PHASES, + traffic_type="PHASES", + max_invocations=100, + model_latency_thresholds=IR_SAMPLE_MODEL_LATENCY_THRESHOLDS, + max_tests=5, + max_parallel_tests=5, + ) + + +# TODO -> removed once model registry is decoupled +def test_right_size_missing_model_package_arn(sagemaker_session, model): + with pytest.raises( + ValueError, + match="right_size\\(\\) is currently only supported with a registered model", + ): + model.right_size( + sample_payload_url=IR_SAMPLE_PAYLOAD_URL, + 
supported_content_types=IR_SUPPORTED_CONTENT_TYPES, + supported_instance_types=[IR_SAMPLE_INSTANCE_TYPE], + job_name=IR_JOB_NAME, + framework=IR_SAMPLE_FRAMEWORK, + ) + + +# TODO check our framework mapping when we add in inference_recommendation_id support + + +@patch("sagemaker.production_variant") +@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME) +def test_deploy_right_size_with_model_package_succeeds( + production_variant, default_right_sized_model +): + default_right_sized_model.deploy(endpoint_name=IR_DEPLOY_ENDPOINT_NAME) + + assert production_variant.called_with( + model_name=MODEL_NAME, + instance_type=IR_RIGHT_SIZE_INSTANCE_TYPE, + initial_instance_count=IR_RIGHT_SIZE_INITIAL_INSTANCE_COUNT, + accelerator_type=None, + serverless_inference_config=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + ) + + +@patch("sagemaker.production_variant") +@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME) +def test_deploy_right_size_with_both_overrides_succeeds( + production_variant, default_right_sized_model +): + default_right_sized_model.deploy( + instance_type="ml.c5.2xlarge", + initial_instance_count=5, + endpoint_name=IR_DEPLOY_ENDPOINT_NAME, + ) + + assert production_variant.called_with( + model_name=MODEL_NAME, + instance_type="ml.c5.2xlarge", + initial_instance_count=5, + accelerator_type=None, + serverless_inference_config=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + ) + + +def test_deploy_right_size_instance_type_override_fails(default_right_sized_model): + with pytest.raises( + ValueError, + match="Must specify instance type and instance count unless using serverless inference", + ): + default_right_sized_model.deploy( + instance_type="ml.c5.2xlarge", + endpoint_name=IR_DEPLOY_ENDPOINT_NAME, + ) + + +def test_deploy_right_size_initial_instance_count_override_fails(default_right_sized_model): + with 
pytest.raises( + ValueError, + match="Must specify instance type and instance count unless using serverless inference", + ): + default_right_sized_model.deploy( + initial_instance_count=2, + endpoint_name=IR_DEPLOY_ENDPOINT_NAME, + ) + + +def test_deploy_right_size_accelerator_type_fails(default_right_sized_model): + with pytest.raises( + ValueError, + match="accelerator_type is not compatible with right_size\\(\\).", + ): + default_right_sized_model.deploy(accelerator_type="ml.eia.medium") + + +@patch("sagemaker.production_variant") +@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME) +def test_deploy_right_size_serverless_override(production_variant, default_right_sized_model): + serverless_inference_config = ServerlessInferenceConfig() + default_right_sized_model.deploy(serverless_inference_config=serverless_inference_config) + + assert production_variant.called_with( + model_name=MODEL_NAME, + instance_type=None, + initial_instance_count=None, + accelerator_type=None, + serverless_inference_config=serverless_inference_config._to_request_dict, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + ) + + +@patch("sagemaker.utils.name_from_base", return_value=MODEL_NAME) +def test_deploy_right_size_async_override(sagemaker_session, default_right_sized_model): + async_inference_config = AsyncInferenceConfig(output_path="s3://some-path") + default_right_sized_model.deploy( + instance_type="ml.c5.2xlarge", + initial_instance_count=1, + async_inference_config=async_inference_config, + ) + + assert sagemaker_session.endpoint_from_production_variants.called_with( + name=MODEL_NAME, + production_variants=[ANY], + tags=None, + kms_key=None, + wait=None, + data_capture_config_dict=None, + async_inference_config_dict=async_inference_config._to_request_dict, + ) + + +# TODO -> cover inference_recommendation_id cases +# ... 
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 119d08cef4..4f951dfcfe 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -23,7 +23,7 @@ from mock import ANY, MagicMock, Mock, patch, call, mock_open import sagemaker -from sagemaker import TrainingInput, Session, get_execution_role +from sagemaker import TrainingInput, Session, get_execution_role, exceptions from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.session import ( _tuning_job_status, @@ -2267,7 +2267,6 @@ def test_train_done_in_progress(sagemaker_session): "GenerateCandidateDefinitionsOnly": False, } - COMPLETE_EXPECTED_AUTO_ML_JOB_ARGS = { "AutoMLJobName": JOB_NAME, "InputDataConfig": [ @@ -2937,3 +2936,335 @@ def test_wait_for_athena_query(query_execution, sagemaker_session): query_execution.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}} sagemaker_session.wait_for_athena_query(query_execution_id="query_id") assert query_execution.called_with(query_execution_id="query_id") + + +IR_USER_JOB_NAME = "custom-job-name" +IR_JOB_NAME = "SMPYTHONSDK-sample-unique-uuid" +IR_ADVANCED_JOB = "Advanced" +IR_ROLE_ARN = "arn:aws:iam::123456789123:role/service-role/AmazonSageMaker-ExecutionRole-UnitTest" +IR_SAMPLE_PAYLOAD_URL = "s3://sagemaker-us-west-2-123456789123/payload/payload.tar.gz" +IR_SUPPORTED_CONTENT_TYPES = ["text/csv"] +IR_MODEL_PACKAGE_VERSION_ARN = ( + "arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1" +) +IR_NEAREST_MODEL_NAME = "xgboost" +IR_SUPPORTED_INSTANCE_TYPES = ["ml.c5.xlarge", "ml.c5.2xlarge"] +IR_FRAMEWORK = "XGBOOST" +IR_FRAMEWORK_VERSION = "1.2.0" +IR_NEAREST_MODEL_NAME = "xgboost" +IR_JOB_DURATION_IN_SECONDS = 7200 +IR_ENDPOINT_CONFIGURATIONS = [ + { + "EnvironmentParameterRanges": { + "CategoricalParameterRanges": [{"Name": "OMP_NUM_THREADS", "Value": ["2", "4", "10"]}] + }, + "InferenceSpecificationName": "unit-test-specification", + "InstanceType": 
"ml.c5.xlarge", + } +] +IR_TRAFFIC_PATTERN = { + "Phases": [{"DurationInSeconds": 120, "InitialNumberOfUsers": 1, "SpawnRate": 1}], + "TrafficType": "PHASES", +} +IR_STOPPING_CONDITIONS = { + "MaxInvocations": 300, + "ModelLatencyThresholds": [{"Percentile": "P95", "ValueInMilliseconds": 100}], +} +IR_RESOURCE_LIMIT = {"MaxNumberOfTests": 10, "MaxParallelOfTests": 1} + + +def create_inference_recommendations_job_default_happy_response(): + return { + "JobName": IR_USER_JOB_NAME, + "JobType": "Default", + "RoleArn": IR_ROLE_ARN, + "InputConfig": { + "ContainerConfig": { + "Domain": "MACHINE_LEARNING", + "Task": "OTHER", + "Framework": IR_FRAMEWORK, + "PayloadConfig": { + "SamplePayloadUrl": IR_SAMPLE_PAYLOAD_URL, + "SupportedContentTypes": IR_SUPPORTED_CONTENT_TYPES, + }, + "FrameworkVersion": IR_FRAMEWORK_VERSION, + "NearestModelName": IR_NEAREST_MODEL_NAME, + "SupportedInstanceTypes": IR_SUPPORTED_INSTANCE_TYPES, + }, + "ModelPackageVersionArn": IR_MODEL_PACKAGE_VERSION_ARN, + }, + "JobDescription": "#python-sdk-create", + } + + +def create_inference_recommendations_job_advanced_happy_response(): + base_advanced_job_response = create_inference_recommendations_job_default_happy_response() + + base_advanced_job_response["JobName"] = IR_JOB_NAME + base_advanced_job_response["JobType"] = IR_ADVANCED_JOB + base_advanced_job_response["StoppingConditions"] = IR_STOPPING_CONDITIONS + base_advanced_job_response["InputConfig"]["JobDurationInSeconds"] = IR_JOB_DURATION_IN_SECONDS + base_advanced_job_response["InputConfig"]["EndpointConfigurations"] = IR_ENDPOINT_CONFIGURATIONS + base_advanced_job_response["InputConfig"]["TrafficPattern"] = IR_TRAFFIC_PATTERN + base_advanced_job_response["InputConfig"]["ResourceLimit"] = IR_RESOURCE_LIMIT + + return base_advanced_job_response + + +def test_create_inference_recommendations_job_default_happy(sagemaker_session): + job_name = sagemaker_session.create_inference_recommendations_job( + role=IR_ROLE_ARN, + 
sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
+        supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
+        model_package_version_arn=IR_MODEL_PACKAGE_VERSION_ARN,
+        framework=IR_FRAMEWORK,
+        framework_version=IR_FRAMEWORK_VERSION,
+        nearest_model_name=IR_NEAREST_MODEL_NAME,
+        supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
+        job_name=IR_USER_JOB_NAME,
+    )
+
+    sagemaker_session.sagemaker_client.create_inference_recommendations_job.assert_called_with(
+        **create_inference_recommendations_job_default_happy_response()
+    )
+
+    assert IR_USER_JOB_NAME == job_name
+
+
+@patch("uuid.uuid4", MagicMock(return_value="sample-unique-uuid"))
+def test_create_inference_recommendations_job_advanced_happy(sagemaker_session):
+    job_name = sagemaker_session.create_inference_recommendations_job(
+        role=IR_ROLE_ARN,
+        sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
+        supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
+        model_package_version_arn=IR_MODEL_PACKAGE_VERSION_ARN,
+        framework=IR_FRAMEWORK,
+        framework_version=IR_FRAMEWORK_VERSION,
+        nearest_model_name=IR_NEAREST_MODEL_NAME,
+        supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
+        endpoint_configurations=IR_ENDPOINT_CONFIGURATIONS,
+        traffic_pattern=IR_TRAFFIC_PATTERN,
+        stopping_conditions=IR_STOPPING_CONDITIONS,
+        resource_limit=IR_RESOURCE_LIMIT,
+        job_type=IR_ADVANCED_JOB,
+        job_duration_in_seconds=IR_JOB_DURATION_IN_SECONDS,
+    )
+
+    sagemaker_session.sagemaker_client.create_inference_recommendations_job.assert_called_with(
+        **create_inference_recommendations_job_advanced_happy_response()
+    )
+
+    assert IR_JOB_NAME == job_name
+
+
+def test_create_inference_recommendations_job_propagate_validation_exception(sagemaker_session):
+    validation_exception_message = (
+        "Failed to describe model due to validation failure with following error: test_error"
+    )
+
+    validation_exception = ClientError(
+        {"Error": {"Code": "ValidationException", "Message": validation_exception_message}},
+        "create_inference_recommendations_job",
+    )
+
+    
sagemaker_session.sagemaker_client.create_inference_recommendations_job.side_effect = (
+        validation_exception
+    )
+
+    with pytest.raises(ClientError) as error:
+        sagemaker_session.create_inference_recommendations_job(
+            role=IR_ROLE_ARN,
+            sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
+            supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
+            model_package_version_arn=IR_MODEL_PACKAGE_VERSION_ARN,
+            framework=IR_FRAMEWORK,
+            framework_version=IR_FRAMEWORK_VERSION,
+            nearest_model_name=IR_NEAREST_MODEL_NAME,
+            supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
+        )
+
+    assert "ValidationException" in str(error)
+
+
+def test_create_inference_recommendations_job_propagate_other_exception(sagemaker_session):
+    access_denied_exception_message = "Access is not allowed for the caller."
+
+    access_denied_exception = ClientError(
+        {"Error": {"Code": "AccessDeniedException", "Message": access_denied_exception_message}},
+        "create_inference_recommendations_job",
+    )
+
+    sagemaker_session.sagemaker_client.create_inference_recommendations_job.side_effect = (
+        access_denied_exception
+    )
+
+    with pytest.raises(ClientError) as error:
+        sagemaker_session.create_inference_recommendations_job(
+            role=IR_ROLE_ARN,
+            sample_payload_url=IR_SAMPLE_PAYLOAD_URL,
+            supported_content_types=IR_SUPPORTED_CONTENT_TYPES,
+            model_package_version_arn=IR_MODEL_PACKAGE_VERSION_ARN,
+            framework=IR_FRAMEWORK,
+            framework_version=IR_FRAMEWORK_VERSION,
+            nearest_model_name=IR_NEAREST_MODEL_NAME,
+            supported_instance_types=IR_SUPPORTED_INSTANCE_TYPES,
+        )
+
+    assert "AccessDeniedException" in str(error)
+
+
+DEFAULT_LOG_EVENTS_INFERENCE_RECOMMENDER = [
+    MockBotoException("ResourceNotFoundException"),
+    {"nextForwardToken": None, "events": [{"timestamp": 1, "message": "hi there #1"}]},
+    {"nextForwardToken": None, "events": [{"timestamp": 2, "message": "hi there #2"}]},
+    {"nextForwardToken": None, "events": [{"timestamp": 3, "message": "hi there #3"}]},
+    {"nextForwardToken": None, "events": 
[{"timestamp": 4, "message": "hi there #4"}]}, +] + +FLUSH_LOG_EVENTS_INFERENCE_RECOMMENDER = [ + MockBotoException("ResourceNotFoundException"), + {"nextForwardToken": None, "events": [{"timestamp": 1, "message": "hi there #1"}]}, + {"nextForwardToken": None, "events": [{"timestamp": 2, "message": "hi there #2"}]}, + {"nextForwardToken": None, "events": []}, + {"nextForwardToken": None, "events": [{"timestamp": 3, "message": "hi there #3"}]}, + {"nextForwardToken": None, "events": []}, + {"nextForwardToken": None, "events": [{"timestamp": 4, "message": "hi there #4"}]}, +] + +INFERENCE_RECOMMENDATIONS_DESC_STATUS_PENDING = {"Status": "PENDING"} +INFERENCE_RECOMMENDATIONS_DESC_STATUS_IN_PROGRESS = {"Status": "IN_PROGRESS"} +INFERENCE_RECOMMENDATIONS_DESC_STATUS_COMPLETED = {"Status": "COMPLETED"} + + +@pytest.fixture() +def sm_session_inference_recommender(): + boto_mock = MagicMock(name="boto_session") + boto_mock.client("logs").get_log_events.side_effect = DEFAULT_LOG_EVENTS_INFERENCE_RECOMMENDER + + ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock()) + + ims.sagemaker_client.describe_inference_recommendations_job.side_effect = [ + INFERENCE_RECOMMENDATIONS_DESC_STATUS_PENDING, + INFERENCE_RECOMMENDATIONS_DESC_STATUS_IN_PROGRESS, + INFERENCE_RECOMMENDATIONS_DESC_STATUS_COMPLETED, + INFERENCE_RECOMMENDATIONS_DESC_STATUS_COMPLETED, + ] + + return ims + + +@pytest.fixture() +def sm_session_inference_recommender_flush(): + boto_mock = MagicMock(name="boto_session") + boto_mock.client("logs").get_log_events.side_effect = FLUSH_LOG_EVENTS_INFERENCE_RECOMMENDER + + ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=MagicMock()) + + ims.sagemaker_client.describe_inference_recommendations_job.side_effect = [ + INFERENCE_RECOMMENDATIONS_DESC_STATUS_PENDING, + INFERENCE_RECOMMENDATIONS_DESC_STATUS_IN_PROGRESS, + INFERENCE_RECOMMENDATIONS_DESC_STATUS_IN_PROGRESS, + INFERENCE_RECOMMENDATIONS_DESC_STATUS_COMPLETED, + 
INFERENCE_RECOMMENDATIONS_DESC_STATUS_COMPLETED, + INFERENCE_RECOMMENDATIONS_DESC_STATUS_COMPLETED, + ] + + return ims + + +@patch("time.sleep") +def test_wait_for_inference_recommendations_job_completed(sleep, sm_session_inference_recommender): + assert ( + sm_session_inference_recommender.wait_for_inference_recommendations_job( + JOB_NAME, log_level="Quiet" + )["Status"] + == "COMPLETED" + ) + + assert ( + 4 + == sm_session_inference_recommender.sagemaker_client.describe_inference_recommendations_job.call_count + ) + assert 2 == sleep.call_count + sleep.assert_has_calls([call(120), call(120)]) + + +def test_wait_for_inference_recommendations_job_failed(sagemaker_session): + inference_recommendations_desc_status_failed = { + "Status": "FAILED", + "FailureReason": "Mock Failure Reason", + } + + sagemaker_session.sagemaker_client.describe_inference_recommendations_job = Mock( + name="describe_inference_recommendations_job", + return_value=inference_recommendations_desc_status_failed, + ) + + with pytest.raises(exceptions.UnexpectedStatusException) as error: + sagemaker_session.wait_for_inference_recommendations_job(JOB_NAME) + + assert "Mock Failure Reason" in str(error) + + +@patch("builtins.print") +@patch("time.sleep") +def test_wait_for_inference_recommendations_job_completed_verbose( + sleep, mock_print, sm_session_inference_recommender +): + assert ( + sm_session_inference_recommender.wait_for_inference_recommendations_job( + JOB_NAME, log_level="Verbose" + )["Status"] + == "COMPLETED" + ) + assert ( + 4 + == sm_session_inference_recommender.sagemaker_client.describe_inference_recommendations_job.call_count + ) + + assert ( + 5 == sm_session_inference_recommender.boto_session.client("logs").get_log_events.call_count + ) + + assert 3 == sleep.call_count + sleep.assert_has_calls([call(10), call(60), call(60)]) + + assert 8 == mock_print.call_count + + +@patch("builtins.print") +@patch("time.sleep") +def 
test_wait_for_inference_recommendations_job_flush_completed( + sleep, mock_print, sm_session_inference_recommender_flush +): + assert ( + sm_session_inference_recommender_flush.wait_for_inference_recommendations_job( + JOB_NAME, log_level="Verbose" + )["Status"] + == "COMPLETED" + ) + assert ( + 6 + == sm_session_inference_recommender_flush.sagemaker_client.describe_inference_recommendations_job.call_count + ) + + assert ( + 7 + == sm_session_inference_recommender_flush.boto_session.client( + "logs" + ).get_log_events.call_count + ) + + assert 5 == sleep.call_count + sleep.assert_has_calls([call(10), call(60), call(60), call(60), call(60)]) + + assert 8 == mock_print.call_count + + +def test_wait_for_inference_recommendations_job_invalid_log_level(sagemaker_session): + with pytest.raises(ValueError) as error: + sagemaker_session.wait_for_inference_recommendations_job( + JOB_NAME, log_level="invalid_log_level" + ) + + assert "log_level must be either Quiet or Verbose" in str(error)