diff --git a/doc/overview.rst b/doc/overview.rst
index 9cfcfbadaa..7f6490a58b 100644
--- a/doc/overview.rst
+++ b/doc/overview.rst
@@ -1164,7 +1164,8 @@ More information about SageMaker Asynchronous Inference can be found in the `AWS
 To deploy asynchronous inference endpoint, you will need to create a
 ``AsyncInferenceConfig`` object. If you create ``AsyncInferenceConfig`` without
 specifying its arguments, the default ``S3OutputPath`` will
-be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME}``. (example shown below):
+be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME}`` and the default ``S3FailurePath`` will
+be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-failures/{UNIQUE-JOB-NAME}`` (example shown below):

 .. code:: python

@@ -1174,18 +1175,21 @@ be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME

    async_config = AsyncInferenceConfig()

 Or you can specify configurations in ``AsyncInferenceConfig`` as you like. All of those configuration parameters
-are optional but if you don’t specify the ``output_path``, Amazon SageMaker will use the default ``S3OutputPath``
+are optional, but if you don’t specify the ``output_path`` or ``failure_path``, Amazon SageMaker will use the
+default ``S3OutputPath`` or ``S3FailurePath``
 mentioned above (example shown below):

 .. code:: python

-   # Specify S3OutputPath, MaxConcurrentInvocationsPerInstance and NotificationConfig in the async config object
+   # Specify S3OutputPath, S3FailurePath, MaxConcurrentInvocationsPerInstance and NotificationConfig
+   # in the async config object
    async_config = AsyncInferenceConfig(
        output_path="s3://{s3_bucket}/{bucket_prefix}/output",
        max_concurrent_invocations_per_instance=10,
        notification_config = {
            "SuccessTopic": "arn:aws:sns:aws-region:account-id:topic-name",
            "ErrorTopic": "arn:aws:sns:aws-region:account-id:topic-name",
+           "IncludeInferenceResponseIn": ["SUCCESS_NOTIFICATION_TOPIC", "ERROR_NOTIFICATION_TOPIC"],
        }
    )
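As a quick illustration (not part of the patch), deploying with both paths set explicitly might look as follows; the bucket, instance type, and ``model`` object are placeholders:

.. code:: python

    from sagemaker.async_inference import AsyncInferenceConfig

    # `model` is assumed to be an already-constructed sagemaker.model.Model
    async_config = AsyncInferenceConfig(
        output_path="s3://my-bucket/async-outputs",
        failure_path="s3://my-bucket/async-failures",
    )
    predictor = model.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.xlarge",
        async_inference_config=async_config,
    )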
diff --git a/src/sagemaker/async_inference/async_inference_config.py b/src/sagemaker/async_inference/async_inference_config.py
index f5e2cb8f57..eb8b2627bb 100644
--- a/src/sagemaker/async_inference/async_inference_config.py
+++ b/src/sagemaker/async_inference/async_inference_config.py
@@ -31,6 +31,7 @@ def __init__(
         max_concurrent_invocations_per_instance=None,
         kms_key_id=None,
         notification_config=None,
+        failure_path=None,
     ):
         """Initialize an AsyncInferenceConfig object for async inference configuration.

@@ -45,6 +46,9 @@ def __init__(
             kms_key_id (str): Optional. The Amazon Web Services Key Management Service
                 (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the
                 asynchronous inference output in Amazon S3. (Default: None)
+            failure_path (str): Optional. The Amazon S3 location to which endpoints upload
+                model responses for failed requests. If no value is provided, Amazon
+                SageMaker will use the default Amazon S3 async inference failure path.
+                (Default: None)
             notification_config (dict): Optional. Specifies the configuration for notifications
                 of inference results for asynchronous inference. Only one notification is
                 generated per invocation request (Default: None):
@@ -54,17 +58,24 @@ def __init__(
                 * error_topic (str): Amazon SNS topic to post a notification to when inference
                     fails. If no topic is provided, no notification is sent on failure.
                     The key in notification_config is 'ErrorTopic'.
+                * include_inference_response_in (list): Optional. When provided, the inference
+                    response will be included in the notification topics. If not provided,
+                    a notification will still be generated on success or error, but it will
+                    not contain the inference response. Valid values are
+                    'SUCCESS_NOTIFICATION_TOPIC' and 'ERROR_NOTIFICATION_TOPIC'.
+                    The key in notification_config is 'IncludeInferenceResponseIn'.
         """
         self.output_path = output_path
         self.max_concurrent_invocations_per_instance = max_concurrent_invocations_per_instance
         self.kms_key_id = kms_key_id
         self.notification_config = notification_config
+        self.failure_path = failure_path

     def _to_request_dict(self):
         """Generates a request dictionary using the parameters provided to the class."""
         request_dict = {
             "OutputConfig": {
                 "S3OutputPath": self.output_path,
+                "S3FailurePath": self.failure_path,
             },
         }
diff --git a/src/sagemaker/async_inference/async_inference_response.py b/src/sagemaker/async_inference/async_inference_response.py
index c0e4f8a83d..fb195597c9 100644
--- a/src/sagemaker/async_inference/async_inference_response.py
+++ b/src/sagemaker/async_inference/async_inference_response.py
@@ -17,7 +17,11 @@
 from botocore.exceptions import ClientError
 from sagemaker.s3 import parse_s3_url
 from sagemaker.async_inference import WaiterConfig
-from sagemaker.exceptions import ObjectNotExistedError, UnexpectedClientError
+from sagemaker.exceptions import (
+    ObjectNotExistedError,
+    UnexpectedClientError,
+    AsyncInferenceModelError,
+)


 class AsyncInferenceResponse(object):
@@ -32,6 +36,7 @@ def __init__(
         self,
         predictor_async,
         output_path,
+        failure_path,
     ):
         """Initialize an AsyncInferenceResponse object.

@@ -43,10 +48,13 @@ def __init__(
                 that return this response.
             output_path (str): The Amazon S3 location that endpoints upload inference
                 responses to.
+            failure_path (str): The Amazon S3 location to which endpoints upload model
+                errors for failed requests.
         """
         self.predictor_async = predictor_async
         self.output_path = output_path
         self._result = None
+        self.failure_path = failure_path

     def get_result(
         self,
@@ -71,28 +79,34 @@ def get_result(
         if self._result is None:
             if waiter_config is None:
-                self._result = self._get_result_from_s3(self.output_path)
+                self._result = self._get_result_from_s3(self.output_path, self.failure_path)
             else:
                 self._result = self.predictor_async._wait_for_output(
-                    self.output_path, waiter_config
+                    self.output_path, self.failure_path, waiter_config
                 )
         return self._result

-    def _get_result_from_s3(
-        self,
-        output_path,
-    ):
+    def _get_result_from_s3(self, output_path, failure_path):
         """Get inference result from the output Amazon S3 path"""
         bucket, key = parse_s3_url(output_path)
         try:
             response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
             return self.predictor_async.predictor._handle_response(response)
-        except ClientError as ex:
-            if ex.response["Error"]["Code"] == "NoSuchKey":
-                raise ObjectNotExistedError(
-                    message="Inference could still be running",
-                    output_path=output_path,
-                )
-            raise UnexpectedClientError(
-                message=ex.response["Error"]["Message"],
-            )
+        except ClientError as e:
+            if e.response["Error"]["Code"] == "NoSuchKey":
+                # No output object yet: check whether the endpoint wrote a failure
+                # object for this request instead.
+                try:
+                    failure_bucket, failure_key = parse_s3_url(failure_path)
+                    failure_response = self.predictor_async.s3_client.get_object(
+                        Bucket=failure_bucket, Key=failure_key
+                    )
+                    failure_response = self.predictor_async.predictor._handle_response(
+                        failure_response
+                    )
+                    raise AsyncInferenceModelError(message=failure_response)
+                except ClientError as ex:
+                    if ex.response["Error"]["Code"] == "NoSuchKey":
+                        # Neither the output nor the failure object exists yet.
+                        raise ObjectNotExistedError(
+                            message="Inference could still be running", output_path=output_path
+                        )
+                    raise UnexpectedClientError(message=ex.response["Error"]["Message"])
+            raise UnexpectedClientError(message=e.response["Error"]["Message"])
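For reference, a sketch (not part of the patch) of how a caller would consume the new failure handling on the ``AsyncInferenceResponse`` path; the input path is a placeholder and ``predictor_async`` is assumed to be an existing ``AsyncPredictor``:

.. code:: python

    from sagemaker.exceptions import AsyncInferenceModelError, ObjectNotExistedError

    response = predictor_async.predict_async(input_path="s3://my-bucket/input/payload.csv")
    try:
        result = response.get_result()
    except ObjectNotExistedError:
        # Neither the output object nor the failure object exists yet;
        # the inference may still be running.
        result = None
    except AsyncInferenceModelError as err:
        # The endpoint wrote a failure object for this request.
        print(f"Request failed in the model: {err}")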
diff --git a/src/sagemaker/exceptions.py b/src/sagemaker/exceptions.py
index 6435615e3e..b9d97cc241 100644
--- a/src/sagemaker/exceptions.py
+++ b/src/sagemaker/exceptions.py
@@ -77,3 +77,12 @@ def __init__(self, **kwargs):
         msg = self.fmt.format(**kwargs)
         Exception.__init__(self, msg)
         self.kwargs = kwargs
+
+
+class AsyncInferenceModelError(AsyncInferenceError):
+    """Raised when the model returns an error for a failed request"""
+
+    fmt = "Model returned error: {message} "
+
+    def __init__(self, message):
+        super().__init__(message=message)
diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py
index 38286f5205..2d22a45359 100644
--- a/src/sagemaker/model.py
+++ b/src/sagemaker/model.py
@@ -1277,7 +1277,10 @@ def deploy(
         async_inference_config_dict = None
         if is_async:
-            if async_inference_config.output_path is None:
+            if (
+                async_inference_config.output_path is None
+                or async_inference_config.failure_path is None
+            ):
                 async_inference_config = self._build_default_async_inference_config(
                     async_inference_config
                 )
@@ -1316,11 +1319,19 @@ def deploy(

     def _build_default_async_inference_config(self, async_inference_config):
         """Build default async inference config and return ``AsyncInferenceConfig``"""
-        async_output_folder = unique_name_from_base(self.name)
-        async_output_s3uri = "s3://{}/async-endpoint-outputs/{}".format(
-            self.sagemaker_session.default_bucket(), async_output_folder
-        )
-        async_inference_config.output_path = async_output_s3uri
+        unique_folder = unique_name_from_base(self.name)
+        if async_inference_config.output_path is None:
+            async_output_s3uri = "s3://{}/async-endpoint-outputs/{}".format(
+                self.sagemaker_session.default_bucket(), unique_folder
+            )
+            async_inference_config.output_path = async_output_s3uri
+
+        if async_inference_config.failure_path is None:
+            async_failure_s3uri = "s3://{}/async-endpoint-failures/{}".format(
+                self.sagemaker_session.default_bucket(), unique_folder
+            )
+            async_inference_config.failure_path = async_failure_s3uri
+
         return async_inference_config

     def transformer(
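To make the defaulting behavior concrete, an illustrative sketch (the default bucket and the unique suffix are resolved at deploy time; the values shown are hypothetical):

.. code:: python

    config = AsyncInferenceConfig(output_path="s3://my-bucket/custom-output")
    # deploy() sees that failure_path is None and fills in only the missing field:
    #   config.output_path  -> "s3://my-bucket/custom-output" (kept as provided)
    #   config.failure_path -> "s3://<default-bucket>/async-endpoint-failures/<model-name>-<unique-suffix>"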
"""Placeholder docstring""" from __future__ import absolute_import - +import threading +import time import uuid from botocore.exceptions import WaiterError -from sagemaker.exceptions import PollingTimeoutError +from sagemaker.exceptions import PollingTimeoutError, AsyncInferenceModelError from sagemaker.async_inference import WaiterConfig, AsyncInferenceResponse from sagemaker.s3 import parse_s3_url from sagemaker.session import Session @@ -98,7 +99,10 @@ def predict( self._input_path = input_path response = self._submit_async_request(input_path, initial_args, inference_id) output_location = response["OutputLocation"] - result = self._wait_for_output(output_path=output_location, waiter_config=waiter_config) + failure_location = response["FailureLocation"] + result = self._wait_for_output( + output_path=output_location, failure_path=failure_location, waiter_config=waiter_config + ) return result @@ -141,9 +145,11 @@ def predict_async( self._input_path = input_path response = self._submit_async_request(input_path, initial_args, inference_id) output_location = response["OutputLocation"] + failure_location = response["FailureLocation"] response_async = AsyncInferenceResponse( predictor_async=self, output_path=output_location, + failure_path=failure_location, ) return response_async @@ -209,30 +215,81 @@ def _submit_async_request( return response - def _wait_for_output( - self, - output_path, - waiter_config, - ): + def _wait_for_output(self, output_path, failure_path, waiter_config): """Check the Amazon S3 output path for the output. - Periodically check Amazon S3 output path for async inference result. - Timeout automatically after max attempts reached - """ - bucket, key = parse_s3_url(output_path) - s3_waiter = self.s3_client.get_waiter("object_exists") - try: - s3_waiter.wait(Bucket=bucket, Key=key, WaiterConfig=waiter_config._to_request_dict()) - except WaiterError: - raise PollingTimeoutError( - message="Inference could still be running", - output_path=output_path, - seconds=waiter_config.delay * waiter_config.max_attempts, - ) + This method waits for either the output file or the failure file to be found on the + specified S3 output path. Whichever file is found first, its corresponding event is + triggered, and the method executes the appropriate action based on the event. - s3_object = self.s3_client.get_object(Bucket=bucket, Key=key) - result = self.predictor._handle_response(response=s3_object) - return result + Args: + output_path (str): The S3 path where the output file is expected to be found. + failure_path (str): The S3 path where the failure file is expected to be found. + waiter_config (boto3.waiter.WaiterConfig): The configuration for the S3 waiter. + + Returns: + Any: The deserialized result from the output file, if the output file is found first. + Otherwise, raises an exception. + + Raises: + AsyncInferenceModelError: If the failure file is found before the output file. + PollingTimeoutError: If both files are not found and the S3 waiter + has thrown a WaiterError. 
+        """
+        output_bucket, output_key = parse_s3_url(output_path)
+        failure_bucket, failure_key = parse_s3_url(failure_path)
+
+        output_file_found = threading.Event()
+        failure_file_found = threading.Event()
+
+        def check_output_file():
+            try:
+                output_file_waiter = self.s3_client.get_waiter("object_exists")
+                output_file_waiter.wait(
+                    Bucket=output_bucket,
+                    Key=output_key,
+                    WaiterConfig=waiter_config._to_request_dict(),
+                )
+                output_file_found.set()
+            except WaiterError:
+                pass
+
+        def check_failure_file():
+            try:
+                failure_file_waiter = self.s3_client.get_waiter("object_exists")
+                failure_file_waiter.wait(
+                    Bucket=failure_bucket,
+                    Key=failure_key,
+                    WaiterConfig=waiter_config._to_request_dict(),
+                )
+                failure_file_found.set()
+            except WaiterError:
+                pass
+
+        output_thread = threading.Thread(target=check_output_file)
+        failure_thread = threading.Thread(target=check_failure_file)
+
+        output_thread.start()
+        failure_thread.start()
+
+        # Poll until one of the files is found, or until both waiters have timed out;
+        # without the liveness check this loop would spin forever after a double timeout.
+        while not output_file_found.is_set() and not failure_file_found.is_set():
+            if not output_thread.is_alive() and not failure_thread.is_alive():
+                break
+            time.sleep(1)
+
+        if output_file_found.is_set():
+            s3_object = self.s3_client.get_object(Bucket=output_bucket, Key=output_key)
+            return self.predictor._handle_response(response=s3_object)
+
+        # Only fetch the failure object if it was actually found; otherwise report a timeout.
+        if failure_file_found.is_set():
+            failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
+            failure_response = self.predictor._handle_response(response=failure_object)
+            raise AsyncInferenceModelError(message=failure_response)
+
+        raise PollingTimeoutError(
+            message="Inference could still be running",
+            output_path=output_path,
+            seconds=waiter_config.delay * waiter_config.max_attempts,
+        )

     def update_endpoint(
         self,
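On the synchronous ``predict`` path, callers would see the two failure modes roughly as below (illustration only; the input path and waiter settings are placeholders):

.. code:: python

    from sagemaker.async_inference.waiter_config import WaiterConfig
    from sagemaker.exceptions import AsyncInferenceModelError, PollingTimeoutError

    try:
        result = predictor_async.predict(
            input_path="s3://my-bucket/input/payload.csv",
            waiter_config=WaiterConfig(delay=5, max_attempts=24),
        )
    except AsyncInferenceModelError as err:
        print(f"The failure object was found first: {err}")
    except PollingTimeoutError:
        print("Neither output nor failure object appeared before the waiters timed out.")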
"S3OutputPath": S3_OUTPUT_PATH, + "S3FailurePath": S3_FAILURE_PATH, }, } @@ -29,10 +31,12 @@ OPTIONAL_NOTIFICATION_CONFIG = { "SuccessTopic": "some-sunccess-topic", "ErrorTopic": "some-error-topic", + "IncludeInferenceResponseIn": ["SUCCESS_NOTIFICATION_TOPIC", "ERROR_NOTIFICATION_TOPIC"], } ASYNC_INFERENCE_DICT_WITH_OPTIONAL = { "OutputConfig": { "S3OutputPath": S3_OUTPUT_PATH, + "S3FailurePath": S3_FAILURE_PATH, "KmsKeyId": OPTIONAL_KMS_KEY_ID, "NotificationConfig": OPTIONAL_NOTIFICATION_CONFIG, }, @@ -41,9 +45,12 @@ def test_init_without_optional(): - async_inference_config = AsyncInferenceConfig(output_path=S3_OUTPUT_PATH) + async_inference_config = AsyncInferenceConfig( + output_path=S3_OUTPUT_PATH, failure_path=S3_FAILURE_PATH + ) assert async_inference_config.output_path == S3_OUTPUT_PATH + assert async_inference_config.failure_path == S3_FAILURE_PATH assert async_inference_config.kms_key_id == DEFAULT_KMS_KEY_ID assert ( async_inference_config.max_concurrent_invocations_per_instance @@ -55,6 +62,7 @@ def test_init_without_optional(): def test_init_with_optional(): async_inference_config = AsyncInferenceConfig( output_path=S3_OUTPUT_PATH, + failure_path=S3_FAILURE_PATH, max_concurrent_invocations_per_instance=OPTIONAL_MAX_CONCURRENT_INVOCATIONS, kms_key_id=OPTIONAL_KMS_KEY_ID, notification_config=OPTIONAL_NOTIFICATION_CONFIG, @@ -62,6 +70,7 @@ def test_init_with_optional(): assert async_inference_config.output_path == S3_OUTPUT_PATH assert async_inference_config.kms_key_id == OPTIONAL_KMS_KEY_ID + assert async_inference_config.failure_path == S3_FAILURE_PATH assert ( async_inference_config.max_concurrent_invocations_per_instance == OPTIONAL_MAX_CONCURRENT_INVOCATIONS @@ -70,11 +79,14 @@ def test_init_with_optional(): def test_to_request_dict(): - async_inference_config = AsyncInferenceConfig(output_path=S3_OUTPUT_PATH) + async_inference_config = AsyncInferenceConfig( + output_path=S3_OUTPUT_PATH, failure_path=S3_FAILURE_PATH + ) assert async_inference_config._to_request_dict() == DEFAULT_ASYNC_INFERENCE_DICT async_inference_config_with_optional = AsyncInferenceConfig( output_path=S3_OUTPUT_PATH, + failure_path=S3_FAILURE_PATH, max_concurrent_invocations_per_instance=OPTIONAL_MAX_CONCURRENT_INVOCATIONS, kms_key_id=OPTIONAL_KMS_KEY_ID, notification_config=OPTIONAL_NOTIFICATION_CONFIG, diff --git a/tests/unit/sagemaker/async_inference/test_async_inference_response.py b/tests/unit/sagemaker/async_inference/test_async_inference_response.py index 8a55dd46fa..a1ad6cf4a8 100644 --- a/tests/unit/sagemaker/async_inference/test_async_inference_response.py +++ b/tests/unit/sagemaker/async_inference/test_async_inference_response.py @@ -18,35 +18,77 @@ from sagemaker.predictor import Predictor from sagemaker.predictor_async import AsyncPredictor from sagemaker.async_inference import AsyncInferenceResponse -from sagemaker.exceptions import ObjectNotExistedError, UnexpectedClientError +from sagemaker.exceptions import ( + AsyncInferenceModelError, + ObjectNotExistedError, + UnexpectedClientError, +) DEFAULT_OUTPUT_PATH = "s3://some-output-path/object-name" +DEFAULT_FAILURE_PATH = "s3://some-failure-path/object-name" ENDPOINT_NAME = "some-endpoint-name" RETURN_VALUE = 0 def empty_s3_client(): + """ + Returns a mocked S3 client with the `get_object` method overridden + to raise different exceptions based on the input. 
diff --git a/tests/unit/sagemaker/async_inference/test_async_inference_response.py b/tests/unit/sagemaker/async_inference/test_async_inference_response.py
index 8a55dd46fa..a1ad6cf4a8 100644
--- a/tests/unit/sagemaker/async_inference/test_async_inference_response.py
+++ b/tests/unit/sagemaker/async_inference/test_async_inference_response.py
@@ -18,35 +18,77 @@
 from sagemaker.predictor import Predictor
 from sagemaker.predictor_async import AsyncPredictor
 from sagemaker.async_inference import AsyncInferenceResponse
-from sagemaker.exceptions import ObjectNotExistedError, UnexpectedClientError
+from sagemaker.exceptions import (
+    AsyncInferenceModelError,
+    ObjectNotExistedError,
+    UnexpectedClientError,
+)

 DEFAULT_OUTPUT_PATH = "s3://some-output-path/object-name"
+DEFAULT_FAILURE_PATH = "s3://some-failure-path/object-name"
 ENDPOINT_NAME = "some-endpoint-name"
 RETURN_VALUE = 0


 def empty_s3_client():
+    """
+    Returns a mocked S3 client whose `get_object` raises the following
+    exceptions, in order, on successive calls:
+
+    - `ClientError` with code "NoSuchKey"
+    - `AsyncInferenceModelError`
+    - `ObjectNotExistedError`
+    - `ClientError` with code "SomeOtherError"
+    - `UnexpectedClientError`
+    """
     s3_client = Mock(name="s3-client")

-    client_other_error = ClientError(
-        error_response={"Error": {"Code": "SomeOtherError", "Message": "some-error-message"}},
-        operation_name="client-other-error",
+    client_error_no_such_key = ClientError(
+        error_response={"Error": {"Code": "NoSuchKey"}},
+        operation_name="async-inference-response-test",
     )

-    client_error = ClientError(
-        error_response={"Error": {"Code": "NoSuchKey"}},
+    async_error = AsyncInferenceModelError("some error message")
+
+    object_error = ObjectNotExistedError("some error message", DEFAULT_OUTPUT_PATH)
+
+    client_error_other = ClientError(
+        error_response={"Error": {"Code": "SomeOtherError", "Message": "some error message"}},
         operation_name="async-inference-response-test",
     )

+    unexpected_error = UnexpectedClientError("some error message")
+
+    s3_client.get_object = Mock(
+        name="get_object",
+        side_effect=[
+            client_error_no_such_key,
+            async_error,
+            object_error,
+            client_error_other,
+            unexpected_error,
+        ],
+    )
+    return s3_client
+
+
+def mock_s3_client():
+    """
+    Returns a mocked S3 client whose `get_object` returns a response
+    dictionary containing a mocked `Body` object.
+    """
+    s3_client = Mock(name="s3-client")
     response_body = Mock("body")
     response_body.read = Mock("read", return_value=RETURN_VALUE)
     response_body.close = Mock("close", return_value=None)
-
     s3_client.get_object = Mock(
         name="get_object",
-        side_effect=[client_other_error, client_error, {"Body": response_body}],
+        side_effect=[
+            {"Body": response_body},
+        ],
     )
-
     return s3_client
@@ -61,38 +103,72 @@ def test_init_():
     async_inference_response = AsyncInferenceResponse(
         output_path=DEFAULT_OUTPUT_PATH,
         predictor_async=predictor_async,
+        failure_path=DEFAULT_FAILURE_PATH,
     )

     assert async_inference_response.output_path == DEFAULT_OUTPUT_PATH
+    assert async_inference_response.failure_path == DEFAULT_FAILURE_PATH


-def test_get_result():
+def test_wrong_waiter_config_object():
     predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME))
-    predictor_async.s3_client = empty_s3_client()
     async_inference_response = AsyncInferenceResponse(
         output_path=DEFAULT_OUTPUT_PATH,
         predictor_async=predictor_async,
+        failure_path=DEFAULT_FAILURE_PATH,
     )

-    with pytest.raises(UnexpectedClientError):
-        async_inference_response.get_result()
+    with pytest.raises(
+        ValueError,
+        match="waiter_config should be a WaiterConfig object",
+    ):
+        async_inference_response.get_result(waiter_config={})

-    with pytest.raises(ObjectNotExistedError, match="Inference could still be running"):
-        async_inference_response.get_result()

+def test_get_result_success():
+    """
+    Verifies that the result is returned correctly if no errors occur.
+    """
+    # Initialize AsyncInferenceResponse
+    predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME))
+    predictor_async.s3_client = mock_s3_client()
+    async_inference_response = AsyncInferenceResponse(
+        output_path=DEFAULT_OUTPUT_PATH,
+        predictor_async=predictor_async,
+        failure_path=DEFAULT_FAILURE_PATH,
+    )
     result = async_inference_response.get_result()
     assert async_inference_response._result == result
     assert result == RETURN_VALUE


-def test_wrong_waiter_config_object():
+def test_get_result_verify_exceptions():
+    """
+    Verifies that get_result raises the expected exception
+    when an error occurs while fetching the result.
+    """
+    # Initialize AsyncInferenceResponse
     predictor_async = AsyncPredictor(Predictor(ENDPOINT_NAME))
+    predictor_async.s3_client = empty_s3_client()
     async_inference_response = AsyncInferenceResponse(
         output_path=DEFAULT_OUTPUT_PATH,
         predictor_async=predictor_async,
+        failure_path=DEFAULT_FAILURE_PATH,
     )

+    # Test AsyncInferenceModelError
+    with pytest.raises(AsyncInferenceModelError, match="Model returned error: some error message"):
+        async_inference_response.get_result()
+
+    # Test ObjectNotExistedError
     with pytest.raises(
-        ValueError,
-        match="waiter_config should be a WaiterConfig object",
+        ObjectNotExistedError,
+        match=f"Object not exist at {DEFAULT_OUTPUT_PATH}. some error message",
     ):
-        async_inference_response.get_result(waiter_config={})
+        async_inference_response.get_result()
+
+    # Test UnexpectedClientError
+    with pytest.raises(
+        UnexpectedClientError, match="Encountered unexpected client error: some error message"
+    ):
+        async_inference_response.get_result()
diff --git a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py
index 8afd9cd2e0..fc68abc072 100644
--- a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py
+++ b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py
@@ -649,7 +649,9 @@ def test_deploy_right_size_serverless_override(sagemaker_session, default_right_
 @patch("sagemaker.utils.name_from_base", MagicMock(return_value=MODEL_NAME))
 def test_deploy_right_size_async_override(sagemaker_session, default_right_sized_model):
     default_right_sized_model.name = MODEL_NAME
-    async_inference_config = AsyncInferenceConfig(output_path="s3://some-path")
+    async_inference_config = AsyncInferenceConfig(
+        output_path="s3://some-path", failure_path="s3://some-failure-path"
+    )
     default_right_sized_model.deploy(
         instance_type="ml.c5.2xlarge",
         initial_instance_count=1,
@@ -663,7 +665,12 @@ def test_deploy_right_size_async_override(sagemaker_session, default_right_sized
         kms_key=None,
         wait=True,
         data_capture_config_dict=None,
-        async_inference_config_dict={"OutputConfig": {"S3OutputPath": "s3://some-path"}},
+        async_inference_config_dict={
+            "OutputConfig": {
+                "S3OutputPath": "s3://some-path",
+                "S3FailurePath": "s3://some-failure-path",
+            }
+        },
         explainer_config_dict=None,
     )
diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py
index ba28e80251..3d0689fc73 100644
--- a/tests/unit/sagemaker/model/test_deploy.py
+++ b/tests/unit/sagemaker/model/test_deploy.py
@@ -396,15 +396,19 @@ def test_deploy_wrong_explainer_config(sagemaker_session):
 @patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME)
 @patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
 def test_deploy_async_inference(production_variant, name_from_base, sagemaker_session):
+    s3_output_path = "s3://some-output-path"
+    s3_failure_path = "s3://some-failure-path"
+
     model = Model(
         MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
     )

-    async_inference_config = AsyncInferenceConfig(output_path="s3://some-path")
+    async_inference_config = AsyncInferenceConfig(
+        output_path=s3_output_path, failure_path=s3_failure_path
+    )
+
     async_inference_config_dict = {
-        "OutputConfig": {
-            "S3OutputPath": "s3://some-path",
-        },
+        "OutputConfig": {"S3OutputPath": s3_output_path, "S3FailurePath": s3_failure_path},
     }

     model.deploy(
diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py
index 3e7cbbd7b0..fa51ef6497 100644
--- a/tests/unit/test_estimator.py
+++ b/tests/unit/test_estimator.py
@@ -3269,11 +3269,14 @@ def test_generic_to_deploy_async(sagemaker_session):
     e.fit()

     s3_output_path = "s3://some-s3-path"
+    s3_failure_path = "s3://some-s3-failures-path"

     predictor_async = e.deploy(
         INSTANCE_COUNT,
         INSTANCE_TYPE,
-        async_inference_config=AsyncInferenceConfig(output_path=s3_output_path),
+        async_inference_config=AsyncInferenceConfig(
+            output_path=s3_output_path, failure_path=s3_failure_path
+        ),
     )

     sagemaker_session.create_model.assert_called_once()
diff --git a/tests/unit/test_predictor_async.py b/tests/unit/test_predictor_async.py
index cc55cd32ed..f0b69abe93 100644
--- a/tests/unit/test_predictor_async.py
+++ b/tests/unit/test_predictor_async.py
@@ -14,10 +14,10 @@
 import pytest
 from mock import Mock
-from botocore.exceptions import WaiterError
+from sagemaker.async_inference.waiter_config import WaiterConfig
 from sagemaker.predictor import Predictor
 from sagemaker.predictor_async import AsyncPredictor
-from sagemaker.exceptions import PollingTimeoutError
+from sagemaker.exceptions import AsyncInferenceModelError, PollingTimeoutError

 ENDPOINT = "mxnet_endpoint"
 BUCKET_NAME = "mxnet_endpoint"
@@ -29,6 +29,7 @@
 PRODUCTION_VARIANT_1 = "PRODUCTION_VARIANT_1"
 INFERENCE_ID = "inference-id"
 ASYNC_OUTPUT_LOCATION = "s3://some-output-path/object-name"
+ASYNC_FAILURE_LOCATION = "s3://some-failure-path/object-name"
 ASYNC_INPUT_LOCATION = "s3://some-input-path/object-name"
 ASYNC_CHECK_PERIOD = 1
 ASYNC_PREDICTOR = "async-predictor"
@@ -38,6 +39,8 @@

 ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}

+DEFAULT_WAITER_CONFIG = WaiterConfig(max_attempts=2, delay=2)  # keep unit-test polling short
+

 def empty_sagemaker_session():
     ims = Mock(name="sagemaker_session")
@@ -50,16 +53,22 @@ def empty_sagemaker_session():
         name="invoke_endpoint_async",
         return_value={
             "OutputLocation": ASYNC_OUTPUT_LOCATION,
+            "FailureLocation": ASYNC_FAILURE_LOCATION,
         },
     )
-    response_body = Mock("body")
-    response_body.read = Mock("read", return_value=RETURN_VALUE)
-    response_body.close = Mock("close", return_value=None)
+
+    async_inference_model_error = AsyncInferenceModelError(message="some error from model")
+
+    polling_timeout_error = PollingTimeoutError(
+        message="Inference could still be running",
+        output_path=ASYNC_OUTPUT_LOCATION,
+        seconds=DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts,
+    )

     ims.s3_client = Mock(name="s3_client")
     ims.s3_client.get_object = Mock(
         name="get_object",
-        return_value={"Body": response_body},
+        side_effect=[async_inference_model_error, polling_timeout_error],
     )
     ims.s3_client.put_object = Mock(name="put_object")

@@ -100,6 +109,7 @@ def test_async_predict_call_pass_through():
     call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.call_args
     assert kwargs == expected_request_args
     assert result.output_path == ASYNC_OUTPUT_LOCATION
+    assert result.failure_path == ASYNC_FAILURE_LOCATION


 def test_async_predict_call_with_data():
@@ -148,42 +158,68 @@ def test_async_predict_call_with_data_and_input_path():
     call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.call_args
     assert kwargs == expected_request_args
     assert result.output_path == ASYNC_OUTPUT_LOCATION
+    assert result.failure_path == ASYNC_FAILURE_LOCATION


-def test_async_predict_call_pass_through_wait_result(capsys):
+def test_async_predict_call_verify_exceptions():
     sagemaker_session = empty_sagemaker_session()
     predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))

-    s3_waiter = Mock(name="object_waiter")
-    waiter_error = WaiterError(
-        name="async-predictor-unit-test-waiter-error",
-        reason="test-waiter-error",
-        last_response="some response",
-    )
-    s3_waiter.wait = Mock(name="wait", side_effect=[waiter_error, None])
-    sagemaker_session.s3_client.get_waiter = Mock(
-        name="object_exists",
-        return_value=s3_waiter,
-    )
-
     input_location = "s3://some-input-path"

-    with pytest.raises(PollingTimeoutError, match="Inference could still be running"):
+    with pytest.raises(
+        AsyncInferenceModelError, match="Model returned error: some error from model"
+    ):
         predictor_async.predict(input_path=input_location)

-    result_async = predictor_async.predict(input_path=input_location)
+    with pytest.raises(
+        PollingTimeoutError,
+        match=f"No result at {ASYNC_OUTPUT_LOCATION} after polling for "
+        f"{DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts}"
+        f" seconds. Inference could still be running",
+    ):
+        predictor_async.predict(input_path=input_location, waiter_config=DEFAULT_WAITER_CONFIG)
+
+    assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called
+    assert sagemaker_session.s3_client.get_object.called
     assert sagemaker_session.sagemaker_client.describe_endpoint.not_called
     assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called

-    expected_request_args = {
-        "Accept": DEFAULT_ACCEPT,
-        "InputLocation": input_location,
-        "EndpointName": ENDPOINT,
-    }
-    call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.call_args
-    assert kwargs == expected_request_args
-    assert result_async == RETURN_VALUE

+def test_async_predict_call_pass_through_success():
+    sagemaker_session = empty_sagemaker_session()
+
+    response_body = Mock("body")
+    response_body.read = Mock("read", return_value=RETURN_VALUE)
+    response_body.close = Mock("close", return_value=None)
+
+    sagemaker_session.s3_client = Mock(name="s3_client")
+    sagemaker_session.s3_client.get_object = Mock(
+        name="get_object",
+        return_value={"Body": response_body},
+    )
+    sagemaker_session.s3_client.put_object = Mock(name="put_object")
+
+    predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))
+
+    sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async = Mock(
+        name="invoke_endpoint_async",
+        return_value={
+            "OutputLocation": ASYNC_OUTPUT_LOCATION,
+            "FailureLocation": ASYNC_FAILURE_LOCATION,
+        },
+    )
+
+    input_location = "s3://some-input-path"
+
+    result = predictor_async.predict(
+        input_path=input_location,
+    )
+
+    assert result == RETURN_VALUE
+    assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called
+    # Use real mock assertions; `called_with` and `.not_called` are silent no-ops.
+    sagemaker_session.s3_client.get_waiter.assert_called_with("object_exists")
+    sagemaker_session.sagemaker_client.describe_endpoint.assert_not_called()
+    sagemaker_session.sagemaker_client.describe_endpoint_config.assert_not_called()

 def test_predict_async_call_invalid_input():
@@ -223,6 +259,7 @@ def test_predict_call_with_inference_id():

     assert kwargs == expected_request_args
     assert result.output_path == ASYNC_OUTPUT_LOCATION
+    assert result.failure_path == ASYNC_FAILURE_LOCATION


 def test_update_endpoint_no_args():