diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py
index fcfbfe61ba..9db7fc064b 100644
--- a/tests/integ/test_workflow.py
+++ b/tests/integ/test_workflow.py
@@ -18,6 +18,7 @@
 import subprocess
 import time
 import uuid
+import logging
 from contextlib import contextmanager
 
 import pytest
@@ -75,6 +76,7 @@
 from sagemaker.feature_store.feature_group import FeatureGroup, FeatureDefinition, FeatureTypeEnum
 from tests.integ import DATA_DIR
 from tests.integ.kms_utils import get_or_create_kms_key
+from tests.integ.retry import retries
 
 
 def ordered(obj):
@@ -1850,47 +1852,57 @@ def test_training_job_with_debugger_and_profiler(
         sagemaker_session=sagemaker_session,
     )
 
-    try:
-        response = pipeline.create(role)
-        create_arn = response["PipelineArn"]
-
-        execution = pipeline.start()
-        response = execution.describe()
-        assert response["PipelineArn"] == create_arn
-
+    for _ in retries(
+        max_retry_count=5,
+        exception_message_prefix="Waiting for a successful execution of pipeline",
+        seconds_to_sleep=10,
+    ):
         try:
-            execution.wait(delay=10, max_attempts=60)
-        except WaiterError:
-            pass
-        execution_steps = execution.list_steps()
+            response = pipeline.create(role)
+            create_arn = response["PipelineArn"]
 
-        assert len(execution_steps) == 1
-        assert execution_steps[0].get("FailureReason", "") == ""
-        assert execution_steps[0]["StepName"] == "pytorch-train"
-        assert execution_steps[0]["StepStatus"] == "Succeeded"
+            execution = pipeline.start()
+            response = execution.describe()
+            assert response["PipelineArn"] == create_arn
 
-        training_job_arn = execution_steps[0]["Metadata"]["TrainingJob"]["Arn"]
-        job_description = sagemaker_session.sagemaker_client.describe_training_job(
-            TrainingJobName=training_job_arn.split("/")[1]
-        )
+            try:
+                execution.wait(delay=10, max_attempts=60)
+            except WaiterError:
+                pass
+            execution_steps = execution.list_steps()
 
-        for index, rule in enumerate(rules):
-            config = job_description["DebugRuleConfigurations"][index]
-            assert config["RuleConfigurationName"] == rule.name
-            assert config["RuleEvaluatorImage"] == rule.image_uri
-            assert config["VolumeSizeInGB"] == 0
-            assert (
-                config["RuleParameters"]["rule_to_invoke"] == rule.rule_parameters["rule_to_invoke"]
+            assert len(execution_steps) == 1
+            failure_reason = execution_steps[0].get("FailureReason", "")
+            if failure_reason != "":
+                logging.error(f"Pipeline execution failed with error: {failure_reason}.Retrying..")
+                continue
+            assert execution_steps[0]["StepName"] == "pytorch-train"
+            assert execution_steps[0]["StepStatus"] == "Succeeded"
+
+            training_job_arn = execution_steps[0]["Metadata"]["TrainingJob"]["Arn"]
+            job_description = sagemaker_session.sagemaker_client.describe_training_job(
+                TrainingJobName=training_job_arn.split("/")[1]
             )
-        assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()
-        assert job_description["ProfilingStatus"] == "Enabled"
-        assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500
-    finally:
-        try:
-            pipeline.delete()
-        except Exception:
-            pass
+            for index, rule in enumerate(rules):
+                config = job_description["DebugRuleConfigurations"][index]
+                assert config["RuleConfigurationName"] == rule.name
+                assert config["RuleEvaluatorImage"] == rule.image_uri
+                assert config["VolumeSizeInGB"] == 0
+                assert (
+                    config["RuleParameters"]["rule_to_invoke"]
+                    == rule.rule_parameters["rule_to_invoke"]
+                )
+            assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()
+
+            assert job_description["ProfilingStatus"] == "Enabled"
+            assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500
+            break
+        finally:
+            try:
+                pipeline.delete()
+            except Exception:
+                pass
 
 
 def test_two_processing_job_depends_on(