Skip to content

Commit 46b3cbf

Browse files
mufaddal-rohawala, ahsan-z-khan, and navinsoni
authored
fix: Add retries to pipeline execution (#2719)
* fix: Add retries to pipeline execution

Co-authored-by: Ahsan Khan <[email protected]>
Co-authored-by: Navin Soni <[email protected]>
1 parent 25da5cc commit 46b3cbf

File tree

1 file changed

+47
-35
lines changed

1 file changed

+47
-35
lines changed

tests/integ/test_workflow.py

+47-35
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import subprocess
1919
import time
2020
import uuid
21+
import logging
2122

2223
from contextlib import contextmanager
2324
import pytest
@@ -75,6 +76,7 @@
7576
from sagemaker.feature_store.feature_group import FeatureGroup, FeatureDefinition, FeatureTypeEnum
7677
from tests.integ import DATA_DIR
7778
from tests.integ.kms_utils import get_or_create_kms_key
79+
from tests.integ.retry import retries
7880

7981

8082
def ordered(obj):
@@ -1850,47 +1852,57 @@ def test_training_job_with_debugger_and_profiler(
18501852
sagemaker_session=sagemaker_session,
18511853
)
18521854

1853-
try:
1854-
response = pipeline.create(role)
1855-
create_arn = response["PipelineArn"]
1856-
1857-
execution = pipeline.start()
1858-
response = execution.describe()
1859-
assert response["PipelineArn"] == create_arn
1860-
1855+
for _ in retries(
1856+
max_retry_count=5,
1857+
exception_message_prefix="Waiting for a successful execution of pipeline",
1858+
seconds_to_sleep=10,
1859+
):
18611860
try:
1862-
execution.wait(delay=10, max_attempts=60)
1863-
except WaiterError:
1864-
pass
1865-
execution_steps = execution.list_steps()
1861+
response = pipeline.create(role)
1862+
create_arn = response["PipelineArn"]
18661863

1867-
assert len(execution_steps) == 1
1868-
assert execution_steps[0].get("FailureReason", "") == ""
1869-
assert execution_steps[0]["StepName"] == "pytorch-train"
1870-
assert execution_steps[0]["StepStatus"] == "Succeeded"
1864+
execution = pipeline.start()
1865+
response = execution.describe()
1866+
assert response["PipelineArn"] == create_arn
18711867

1872-
training_job_arn = execution_steps[0]["Metadata"]["TrainingJob"]["Arn"]
1873-
job_description = sagemaker_session.sagemaker_client.describe_training_job(
1874-
TrainingJobName=training_job_arn.split("/")[1]
1875-
)
1868+
try:
1869+
execution.wait(delay=10, max_attempts=60)
1870+
except WaiterError:
1871+
pass
1872+
execution_steps = execution.list_steps()
18761873

1877-
for index, rule in enumerate(rules):
1878-
config = job_description["DebugRuleConfigurations"][index]
1879-
assert config["RuleConfigurationName"] == rule.name
1880-
assert config["RuleEvaluatorImage"] == rule.image_uri
1881-
assert config["VolumeSizeInGB"] == 0
1882-
assert (
1883-
config["RuleParameters"]["rule_to_invoke"] == rule.rule_parameters["rule_to_invoke"]
1874+
assert len(execution_steps) == 1
1875+
failure_reason = execution_steps[0].get("FailureReason", "")
1876+
if failure_reason != "":
1877+
logging.error(f"Pipeline execution failed with error: {failure_reason}.Retrying..")
1878+
continue
1879+
assert execution_steps[0]["StepName"] == "pytorch-train"
1880+
assert execution_steps[0]["StepStatus"] == "Succeeded"
1881+
1882+
training_job_arn = execution_steps[0]["Metadata"]["TrainingJob"]["Arn"]
1883+
job_description = sagemaker_session.sagemaker_client.describe_training_job(
1884+
TrainingJobName=training_job_arn.split("/")[1]
18841885
)
1885-
assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()
18861886

1887-
assert job_description["ProfilingStatus"] == "Enabled"
1888-
assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500
1889-
finally:
1890-
try:
1891-
pipeline.delete()
1892-
except Exception:
1893-
pass
1887+
for index, rule in enumerate(rules):
1888+
config = job_description["DebugRuleConfigurations"][index]
1889+
assert config["RuleConfigurationName"] == rule.name
1890+
assert config["RuleEvaluatorImage"] == rule.image_uri
1891+
assert config["VolumeSizeInGB"] == 0
1892+
assert (
1893+
config["RuleParameters"]["rule_to_invoke"]
1894+
== rule.rule_parameters["rule_to_invoke"]
1895+
)
1896+
assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()
1897+
1898+
assert job_description["ProfilingStatus"] == "Enabled"
1899+
assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500
1900+
break
1901+
finally:
1902+
try:
1903+
pipeline.delete()
1904+
except Exception:
1905+
pass
18941906

18951907

18961908
def test_two_processing_job_depends_on(

0 commit comments

Comments (0)