@@ -99,7 +99,7 @@ def predict(
99
99
self ._input_path = input_path
100
100
response = self ._submit_async_request (input_path , initial_args , inference_id )
101
101
output_location = response ["OutputLocation" ]
102
- failure_location = response [ "FailureLocation" ]
102
+ failure_location = response . get ( "FailureLocation" )
103
103
result = self ._wait_for_output (
104
104
output_path = output_location , failure_path = failure_location , waiter_config = waiter_config
105
105
)
@@ -145,7 +145,7 @@ def predict_async(
145
145
self ._input_path = input_path
146
146
response = self ._submit_async_request (input_path , initial_args , inference_id )
147
147
output_location = response ["OutputLocation" ]
148
- failure_location = response [ "FailureLocation" ]
148
+ failure_location = response . get ( "FailureLocation" )
149
149
response_async = AsyncInferenceResponse (
150
150
predictor_async = self ,
151
151
output_path = output_location ,
@@ -216,6 +216,35 @@ def _submit_async_request(
216
216
return response
217
217
218
218
def _wait_for_output (self , output_path , failure_path , waiter_config ):
219
+ """Retrieve output based on the presense of failure_path."""
220
+ if failure_path is not None :
221
+ return self ._check_output_and_failure_locations (
222
+ output_path , failure_path , waiter_config
223
+ )
224
+
225
+ return self ._check_output_location (output_path , waiter_config )
226
+
227
+ def _check_output_location (self , output_path , waiter_config ):
228
+ """Check the Amazon S3 output path for the output.
229
+
230
+ Periodically check Amazon S3 output path for async inference result.
231
+ Timeout automatically after max attempts reached
232
+ """
233
+ bucket , key = parse_s3_url (output_path )
234
+ s3_waiter = self .s3_client .get_waiter ("object_exists" )
235
+ try :
236
+ s3_waiter .wait (Bucket = bucket , Key = key , WaiterConfig = waiter_config ._to_request_dict ())
237
+ except WaiterError :
238
+ raise PollingTimeoutError (
239
+ message = "Inference could still be running" ,
240
+ output_path = output_path ,
241
+ seconds = waiter_config .delay * waiter_config .max_attempts ,
242
+ )
243
+ s3_object = self .s3_client .get_object (Bucket = bucket , Key = key )
244
+ result = self .predictor ._handle_response (response = s3_object )
245
+ return result
246
+
247
+ def _check_output_and_failure_locations (self , output_path , failure_path , waiter_config ):
219
248
"""Check the Amazon S3 output path for the output.
220
249
221
250
This method waits for either the output file or the failure file to be found on the
0 commit comments