Skip to content

Commit 7818d08

Browse files
nmadanNamrata Madan
authored andcommitted
fix: WaiterError on failed pipeline execution. results() (#1337)
Co-authored-by: Namrata Madan <[email protected]>
1 parent 3cbb262 commit 7818d08

File tree

4 files changed

+32
-12
lines changed

4 files changed

+32
-12
lines changed

src/sagemaker/workflow/pipeline.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@
2222
import attr
2323
import botocore
2424
import pytz
25-
from botocore.exceptions import ClientError
25+
from botocore.exceptions import ClientError, WaiterError
2626

2727
from sagemaker import s3
2828
from sagemaker._studio import _append_project_tags
2929
from sagemaker.config import PIPELINE_ROLE_ARN_PATH, PIPELINE_TAGS_PATH
3030
from sagemaker.remote_function.core.serialization import deserialize_obj_from_s3
3131
from sagemaker.remote_function.core.stored_function import RESULTS_FOLDER
32+
from sagemaker.remote_function.errors import RemoteFunctionError
3233
from sagemaker.remote_function.job import JOBS_CONTAINER_ENTRYPOINT
3334
from sagemaker.s3_utils import s3_path_join
3435
from sagemaker.session import Session
@@ -965,8 +966,13 @@ def result(self, step_name: str):
965966
966967
Raises:
967968
ValueError if the provided step is not a ``@step`` decorated function.
969+
RemoteFunctionError if the provided step is not in "Completed" status
968970
"""
969-
self.wait()
971+
try:
972+
self.wait()
973+
except WaiterError as e:
974+
if "Waiter encountered a terminal failure state" in str(e):
975+
pass
970976
step = next(filter(lambda x: x["StepName"] == step_name, self.list_steps()), None)
971977
if not step:
972978
raise ValueError(f"Invalid step name {step_name}")
@@ -986,15 +992,22 @@ def result(self, step_name: str):
986992
]
987993
if container_args != JOBS_CONTAINER_ENTRYPOINT:
988994
raise ValueError(
989-
"This method can only be used on pipeline steps created using @step" " decorator."
995+
"This method can only be used on pipeline steps created using @step decorator."
990996
)
991-
992997
s3_output_path = describe_training_job_response["OutputDataConfig"]["S3OutputPath"]
993-
return deserialize_obj_from_s3(
994-
sagemaker_session=self.sagemaker_session,
995-
s3_uri=s3_path_join(s3_output_path, self.arn.split("/")[-1], step_name, RESULTS_FOLDER),
996-
hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"],
997-
)
998+
999+
job_status = describe_training_job_response["TrainingJobStatus"]
1000+
if job_status == "Completed":
1001+
return deserialize_obj_from_s3(
1002+
sagemaker_session=self.sagemaker_session,
1003+
s3_uri=s3_path_join(
1004+
s3_output_path, self.arn.split("/")[-1], step_name, RESULTS_FOLDER
1005+
),
1006+
hmac_key=describe_training_job_response["Environment"][
1007+
"REMOTE_FUNCTION_SECRET_KEY"
1008+
],
1009+
)
1010+
raise RemoteFunctionError(f"Pipeline step {step_name} is in {job_status} status.")
9981011

9991012

10001013
class PipelineGraph:

tests/integ/sagemaker/workflow/helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def create_and_execute_pipeline(
7171
if step_result_value:
7272
result = execution.result(execution_steps[0]["StepName"])
7373
assert result == step_result_value, f"Expected {step_result_value}, instead found {result}"
74-
return execution_steps
74+
return execution, execution_steps
7575

7676

7777
def validate_scheduled_pipeline_execution(

tests/integ/sagemaker/workflow/test_step_decorator.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sagemaker import get_execution_role, utils
2323
from sagemaker.config import load_sagemaker_config
2424
from sagemaker.processing import ProcessingInput
25+
from sagemaker.remote_function.errors import RemoteFunctionError
2526
from sagemaker.sklearn import SKLearnProcessor
2627
from sagemaker.remote_function.core.serialization import CloudpickleSerializer
2728
from sagemaker.s3 import S3Uploader
@@ -577,7 +578,7 @@ def divide(x, y):
577578
)
578579

579580
try:
580-
create_and_execute_pipeline(
581+
execution, execution_steps = create_and_execute_pipeline(
581582
pipeline=pipeline,
582583
pipeline_name=pipeline_name,
583584
region_name=region_name,
@@ -587,6 +588,11 @@ def divide(x, y):
587588
execution_parameters=dict(),
588589
step_status="Failed",
589590
)
591+
592+
step_name = execution_steps[0]["StepName"]
593+
with pytest.raises(RemoteFunctionError) as e:
594+
execution.result(step_name)
595+
assert f"Pipeline step {step_name} is in Failed status." in str(e)
590596
finally:
591597
try:
592598
pipeline.delete()
@@ -720,7 +726,7 @@ def func(var: int):
720726
)
721727

722728
try:
723-
execution_steps = create_and_execute_pipeline(
729+
_, execution_steps = create_and_execute_pipeline(
724730
pipeline=pipeline,
725731
pipeline_name=pipeline_name,
726732
region_name=region_name,

tests/unit/sagemaker/workflow/test_pipeline.py

+1
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ def test_pipeline_update(
287287
"/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh",
288288
]
289289
},
290+
"TrainingJobStatus": "Completed",
290291
"OutputDataConfig": {"S3OutputPath": "s3:/my-bucket/my-key"},
291292
"Environment": {"REMOTE_FUNCTION_SECRET_KEY": "abcdefg"},
292293
}

0 commit comments

Comments
 (0)