12 | 12 | # language governing permissions and limitations under the License.
13 | 13 | """Placeholder docstring"""
14 | 14 | from __future__ import absolute_import
15 |    | -
   | 15 | +import threading
   | 16 | +import time
16 | 17 | import uuid
17 | 18 | from botocore.exceptions import WaiterError
18 |    | -from sagemaker.exceptions import PollingTimeoutError
   | 19 | +from sagemaker.exceptions import PollingTimeoutError, AsyncInferenceModelError
19 | 20 | from sagemaker.async_inference import WaiterConfig, AsyncInferenceResponse
20 | 21 | from sagemaker.s3 import parse_s3_url
21 | 22 | from sagemaker.session import Session
@@ -98,7 +99,10 @@ def predict(
98 | 99 |         self._input_path = input_path
99 | 100 |         response = self._submit_async_request(input_path, initial_args, inference_id)
100 | 101 |         output_location = response["OutputLocation"]
101 |     | -        result = self._wait_for_output(output_path=output_location, waiter_config=waiter_config)
    | 102 | +        failure_location = response["FailureLocation"]
    | 103 | +        result = self._wait_for_output(
    | 104 | +            output_path=output_location, failure_path=failure_location, waiter_config=waiter_config
    | 105 | +        )
102 | 106 |
103 | 107 |         return result
104 | 108 |
@@ -141,9 +145,11 @@ def predict_async(
141 | 145 |         self._input_path = input_path
142 | 146 |         response = self._submit_async_request(input_path, initial_args, inference_id)
143 | 147 |         output_location = response["OutputLocation"]
    | 148 | +        failure_location = response["FailureLocation"]
144 | 149 |         response_async = AsyncInferenceResponse(
145 | 150 |             predictor_async=self,
146 | 151 |             output_path=output_location,
    | 152 | +            failure_path=failure_location,
147 | 153 |         )
148 | 154 |
149 | 155 |         return response_async
@@ -209,30 +215,81 @@ def _submit_async_request(
209 | 215 |
210 | 216 |         return response
211 | 217 |
212 |     | -    def _wait_for_output(
213 |     | -        self,
214 |     | -        output_path,
215 |     | -        waiter_config,
216 |     | -    ):
    | 218 | +    def _wait_for_output(self, output_path, failure_path, waiter_config):
217 | 219 |         """Check the Amazon S3 output path for the output.
218 | 220 |
219 |     | -        Periodically check Amazon S3 output path for async inference result.
220 |     | -        Timeout automatically after max attempts reached
221 |     | -        """
222 |     | -        bucket, key = parse_s3_url(output_path)
223 |     | -        s3_waiter = self.s3_client.get_waiter("object_exists")
224 |     | -        try:
225 |     | -            s3_waiter.wait(Bucket=bucket, Key=key, WaiterConfig=waiter_config._to_request_dict())
226 |     | -        except WaiterError:
227 |     | -            raise PollingTimeoutError(
228 |     | -                message="Inference could still be running",
229 |     | -                output_path=output_path,
230 |     | -                seconds=waiter_config.delay * waiter_config.max_attempts,
231 |     | -            )
    | 221 | +        This method waits for either the output file or the failure file to be found on the
    | 222 | +        specified S3 output path. Whichever file is found first, its corresponding event is
    | 223 | +        triggered, and the method executes the appropriate action based on the event.
232 | 224 |
233 |     | -        s3_object = self.s3_client.get_object(Bucket=bucket, Key=key)
234 |     | -        result = self.predictor._handle_response(response=s3_object)
235 |     | -        return result
    | 225 | +        Args:
    | 226 | +            output_path (str): The S3 path where the output file is expected to be found.
    | 227 | +            failure_path (str): The S3 path where the failure file is expected to be found.
    | 228 | +            waiter_config (boto3.waiter.WaiterConfig): The configuration for the S3 waiter.
    | 229 | +
    | 230 | +        Returns:
    | 231 | +            Any: The deserialized result from the output file, if the output file is found first.
    | 232 | +            Otherwise, raises an exception.
    | 233 | +
    | 234 | +        Raises:
    | 235 | +            AsyncInferenceModelError: If the failure file is found before the output file.
    | 236 | +            PollingTimeoutError: If both files are not found and the S3 waiter
    | 237 | +                has thrown a WaiterError.
    | 238 | +        """
    | 239 | +        output_bucket, output_key = parse_s3_url(output_path)
    | 240 | +        failure_bucket, failure_key = parse_s3_url(failure_path)
    | 241 | +
    | 242 | +        output_file_found = threading.Event()
    | 243 | +        failure_file_found = threading.Event()
    | 244 | +
    | 245 | +        def check_output_file():
    | 246 | +            try:
    | 247 | +                output_file_waiter = self.s3_client.get_waiter("object_exists")
    | 248 | +                output_file_waiter.wait(
    | 249 | +                    Bucket=output_bucket,
    | 250 | +                    Key=output_key,
    | 251 | +                    WaiterConfig=waiter_config._to_request_dict(),
    | 252 | +                )
    | 253 | +                output_file_found.set()
    | 254 | +            except WaiterError:
    | 255 | +                pass
    | 256 | +
    | 257 | +        def check_failure_file():
    | 258 | +            try:
    | 259 | +                failure_file_waiter = self.s3_client.get_waiter("object_exists")
    | 260 | +                failure_file_waiter.wait(
    | 261 | +                    Bucket=failure_bucket,
    | 262 | +                    Key=failure_key,
    | 263 | +                    WaiterConfig=waiter_config._to_request_dict(),
    | 264 | +                )
    | 265 | +                failure_file_found.set()
    | 266 | +            except WaiterError:
    | 267 | +                pass
    | 268 | +
    | 269 | +        output_thread = threading.Thread(target=check_output_file)
    | 270 | +        failure_thread = threading.Thread(target=check_failure_file)
    | 271 | +
    | 272 | +        output_thread.start()
    | 273 | +        failure_thread.start()
    | 274 | +
    | 275 | +        while not output_file_found.is_set() and not failure_file_found.is_set():
    | 276 | +            time.sleep(1)
    | 277 | +
    | 278 | +        if output_file_found.is_set():
    | 279 | +            s3_object = self.s3_client.get_object(Bucket=output_bucket, Key=output_key)
    | 280 | +            result = self.predictor._handle_response(response=s3_object)
    | 281 | +            return result
    | 282 | +
    | 283 | +        failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
    | 284 | +        failure_response = self.predictor._handle_response(response=failure_object)
    | 285 | +
    | 286 | +        raise AsyncInferenceModelError(
    | 287 | +            message=failure_response
    | 288 | +        ) if failure_file_found.is_set() else PollingTimeoutError(
    | 289 | +            message="Inference could still be running",
    | 290 | +            output_path=output_path,
    | 291 | +            seconds=waiter_config.delay * waiter_config.max_attempts,
    | 292 | +        )
236 | 293 |
237 | 294 |     def update_endpoint(
238 | 295 |         self,
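Below is a minimal caller-side sketch (not part of the diff above) of how the new failure handling surfaces through `predict()`. It assumes an already-deployed asynchronous endpoint wrapped in an `AsyncPredictor` named `predictor` and a request input `payload`; both names are illustrative.

```python
from sagemaker.async_inference import WaiterConfig
from sagemaker.exceptions import AsyncInferenceModelError, PollingTimeoutError

# `predictor` and `payload` are assumed to exist: an AsyncPredictor bound to a
# deployed asynchronous endpoint, and a payload its serializer can handle.
waiter_config = WaiterConfig(max_attempts=60, delay=15)

try:
    # predict() now also polls the FailureLocation reported by the endpoint.
    result = predictor.predict(data=payload, waiter_config=waiter_config)
except AsyncInferenceModelError as e:
    # The failure file appeared before the output file: the model returned an error.
    print(f"Model error for this invocation: {e}")
except PollingTimeoutError as e:
    # Neither file appeared within delay * max_attempts seconds.
    print(f"Timed out waiting for the async result: {e}")
```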