diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py index f5c1193be8..9d350b01f3 100644 --- a/src/sagemaker/workflow/clarify_check_step.py +++ b/src/sagemaker/workflow/clarify_check_step.py @@ -153,6 +153,7 @@ def __init__( clarify_check_config: ClarifyCheckConfig, check_job_config: CheckJobConfig, skip_check: Union[bool, PipelineVariable] = False, + fail_on_violation: Union[bool, PipelineVariable] = True, register_new_baseline: Union[bool, PipelineVariable] = False, model_package_group_name: Union[str, PipelineVariable] = None, supplied_baseline_constraints: Union[str, PipelineVariable] = None, @@ -169,6 +170,8 @@ def __init__( check_job_config (CheckJobConfig): A CheckJobConfig instance. skip_check (bool or PipelineVariable): Whether the check should be skipped (default: False). + fail_on_violation (bool or PipelineVariable): Whether to fail the step + if violation detected (default: True). register_new_baseline (bool or PipelineVariable): Whether the new baseline should be registered (default: False). model_package_group_name (str or PipelineVariable): The name of a @@ -214,6 +217,7 @@ def __init__( name, display_name, description, StepTypeEnum.CLARIFY_CHECK, depends_on ) self.skip_check = skip_check + self.fail_on_violation = fail_on_violation self.register_new_baseline = register_new_baseline self.clarify_check_config = clarify_check_config self.check_job_config = check_job_config @@ -286,6 +290,7 @@ def to_request(self) -> RequestType: request_dict["ModelPackageGroupName"] = self.model_package_group_name request_dict["SkipCheck"] = self.skip_check + request_dict["FailOnViolation"] = self.fail_on_violation request_dict["RegisterNewBaseline"] = self.register_new_baseline request_dict["SuppliedBaselineConstraints"] = self.supplied_baseline_constraints if isinstance( diff --git a/src/sagemaker/workflow/quality_check_step.py b/src/sagemaker/workflow/quality_check_step.py index d9d3ea2bef..3a6c3ba627 100644 --- a/src/sagemaker/workflow/quality_check_step.py +++ b/src/sagemaker/workflow/quality_check_step.py @@ -117,6 +117,7 @@ def __init__( quality_check_config: QualityCheckConfig, check_job_config: CheckJobConfig, skip_check: Union[bool, PipelineVariable] = False, + fail_on_violation: Union[bool, PipelineVariable] = True, register_new_baseline: Union[bool, PipelineVariable] = False, model_package_group_name: Union[str, PipelineVariable] = None, supplied_baseline_statistics: Union[str, PipelineVariable] = None, @@ -134,6 +135,8 @@ def __init__( check_job_config (CheckJobConfig): A CheckJobConfig instance. skip_check (bool or PipelineVariable): Whether the check should be skipped (default: False). + fail_on_violation (bool or PipelineVariable): Whether to fail the step + if violation detected (default: True). register_new_baseline (bool or PipelineVariable): Whether the new baseline should be registered (default: False). model_package_group_name (str or PipelineVariable): The name of a @@ -165,6 +168,7 @@ def __init__( name, display_name, description, StepTypeEnum.QUALITY_CHECK, depends_on ) self.skip_check = skip_check + self.fail_on_violation = fail_on_violation self.register_new_baseline = register_new_baseline self.check_job_config = check_job_config self.quality_check_config = quality_check_config @@ -257,6 +261,7 @@ def to_request(self) -> RequestType: request_dict["ModelPackageGroupName"] = self.model_package_group_name request_dict["SkipCheck"] = self.skip_check + request_dict["FailOnViolation"] = self.fail_on_violation request_dict["RegisterNewBaseline"] = self.register_new_baseline request_dict["SuppliedBaselineStatistics"] = self.supplied_baseline_statistics request_dict["SuppliedBaselineConstraints"] = self.supplied_baseline_constraints diff --git a/tests/integ/sagemaker/workflow/test_clarify_check_steps.py b/tests/integ/sagemaker/workflow/test_clarify_check_steps.py index b0d4ac6cbb..aea509b4fb 100644 --- a/tests/integ/sagemaker/workflow/test_clarify_check_steps.py +++ b/tests/integ/sagemaker/workflow/test_clarify_check_steps.py @@ -215,11 +215,13 @@ def test_one_step_data_bias_pipeline_happycase( pass +@pytest.mark.parametrize("fail_on_violation", [None, True, False]) def test_one_step_data_bias_pipeline_constraint_violation( sagemaker_session, role, pipeline_name, check_job_config, + fail_on_violation, data_bias_check_config, supplied_baseline_constraints_uri_param, ): @@ -234,6 +236,7 @@ def test_one_step_data_bias_pipeline_constraint_violation( clarify_check_config=data_bias_check_config, check_job_config=check_job_config, skip_check=False, + fail_on_violation=fail_on_violation, register_new_baseline=False, supplied_baseline_constraints=supplied_baseline_constraints_uri_param, ) @@ -276,12 +279,19 @@ def test_one_step_data_bias_pipeline_constraint_violation( execution_steps = execution.list_steps() assert len(execution_steps) == 1 - failure_reason = execution_steps[0].get("FailureReason", "") - if _CHECK_FAIL_ERROR_MSG not in failure_reason: - logging.error(f"Pipeline execution failed with error: {failure_reason}. Retrying..") - continue assert execution_steps[0]["StepName"] == "DataBiasCheckStep" - assert execution_steps[0]["StepStatus"] == "Failed" + failure_reason = execution_steps[0].get("FailureReason", "") + + if fail_on_violation is None or fail_on_violation: + if _CHECK_FAIL_ERROR_MSG not in failure_reason: + logging.error( + f"Pipeline execution failed with error: {failure_reason}. Retrying.." + ) + continue + assert execution_steps[0]["StepStatus"] == "Failed" + else: + assert _CHECK_FAIL_ERROR_MSG not in failure_reason + assert execution_steps[0]["StepStatus"] == "Succeeded" break finally: try: diff --git a/tests/integ/sagemaker/workflow/test_quality_check_steps.py b/tests/integ/sagemaker/workflow/test_quality_check_steps.py index 043989008e..f521751e47 100644 --- a/tests/integ/sagemaker/workflow/test_quality_check_steps.py +++ b/tests/integ/sagemaker/workflow/test_quality_check_steps.py @@ -215,11 +215,13 @@ def test_one_step_data_quality_pipeline_happycase( pass +@pytest.mark.parametrize("fail_on_violation", [None, True, False]) def test_one_step_data_quality_pipeline_constraint_violation( sagemaker_session, role, pipeline_name, check_job_config, + fail_on_violation, supplied_baseline_statistics_uri_param, supplied_baseline_constraints_uri_param, data_quality_check_config, @@ -234,6 +236,7 @@ def test_one_step_data_quality_pipeline_constraint_violation( data_quality_check_step = QualityCheckStep( name="DataQualityCheckStep", skip_check=False, + fail_on_violation=fail_on_violation, register_new_baseline=False, quality_check_config=data_quality_check_config, check_job_config=check_job_config, @@ -274,14 +277,21 @@ def test_one_step_data_quality_pipeline_constraint_violation( except WaiterError: pass execution_steps = execution.list_steps() - assert len(execution_steps) == 1 - failure_reason = execution_steps[0].get("FailureReason", "") - if _CHECK_FAIL_ERROR_MSG not in failure_reason: - logging.error(f"Pipeline execution failed with error: {failure_reason}. Retrying..") - continue assert execution_steps[0]["StepName"] == "DataQualityCheckStep" - assert execution_steps[0]["StepStatus"] == "Failed" + + failure_reason = execution_steps[0].get("FailureReason", "") + if fail_on_violation is None or fail_on_violation: + if _CHECK_FAIL_ERROR_MSG not in failure_reason: + logging.error( + f"Pipeline execution failed with error: {failure_reason}. Retrying.." + ) + continue + assert execution_steps[0]["StepStatus"] == "Failed" + else: + # fail on violation == false + assert _CHECK_FAIL_ERROR_MSG not in failure_reason + assert execution_steps[0]["StepStatus"] == "Succeeded" break finally: try: diff --git a/tests/unit/sagemaker/workflow/test_clarify_check_step.py b/tests/unit/sagemaker/workflow/test_clarify_check_step.py index 508b4a9379..feadaa03dc 100644 --- a/tests/unit/sagemaker/workflow/test_clarify_check_step.py +++ b/tests/unit/sagemaker/workflow/test_clarify_check_step.py @@ -149,6 +149,7 @@ def sagemaker_session(boto_session, client): "CheckType": "DATA_BIAS", "ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"}, "SkipCheck": False, + "FailOnViolation": False, "RegisterNewBaseline": False, "SuppliedBaselineConstraints": "supplied_baseline_constraints", "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, @@ -213,6 +214,7 @@ def sagemaker_session(boto_session, client): "CheckType": "MODEL_BIAS", "ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"}, "SkipCheck": False, + "FailOnViolation": True, "RegisterNewBaseline": False, "SuppliedBaselineConstraints": "supplied_baseline_constraints", "ModelName": "model_name", @@ -277,6 +279,7 @@ def sagemaker_session(boto_session, client): "CheckType": "MODEL_EXPLAINABILITY", "ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"}, "SkipCheck": False, + "FailOnViolation": False, "RegisterNewBaseline": False, "SuppliedBaselineConstraints": "supplied_baseline_constraints", "ModelName": "model_name", @@ -365,6 +368,7 @@ def test_data_bias_check_step( clarify_check_config=data_bias_check_config, check_job_config=check_job_config, skip_check=False, + fail_on_violation=False, register_new_baseline=False, model_package_group_name=model_package_group_name, supplied_baseline_constraints="supplied_baseline_constraints", @@ -406,6 +410,7 @@ def test_model_bias_check_step( clarify_check_config=model_bias_check_config, check_job_config=check_job_config, skip_check=False, + fail_on_violation=True, register_new_baseline=False, model_package_group_name=model_package_group_name, supplied_baseline_constraints="supplied_baseline_constraints", @@ -444,6 +449,7 @@ def test_model_explainability_check_step( clarify_check_config=model_explainability_check_config, check_job_config=check_job_config, skip_check=False, + fail_on_violation=False, register_new_baseline=False, model_package_group_name=model_package_group_name, supplied_baseline_constraints="supplied_baseline_constraints", diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py index 93fd439468..a62b35cd97 100644 --- a/tests/unit/sagemaker/workflow/test_processing_step.py +++ b/tests/unit/sagemaker/workflow/test_processing_step.py @@ -189,9 +189,7 @@ @pytest.fixture def client(): """Mock client. - Considerations when appropriate: - * utilize botocore.stub.Stubber * separate runtime client from client """ diff --git a/tests/unit/sagemaker/workflow/test_quality_check_step.py b/tests/unit/sagemaker/workflow/test_quality_check_step.py index 0e003f00dd..b60e2de8fa 100644 --- a/tests/unit/sagemaker/workflow/test_quality_check_step.py +++ b/tests/unit/sagemaker/workflow/test_quality_check_step.py @@ -154,6 +154,7 @@ def sagemaker_session(boto_session, client): "CheckType": "DATA_QUALITY", "ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"}, "SkipCheck": False, + "FailOnViolation": False, "RegisterNewBaseline": False, "SuppliedBaselineStatistics": {"Get": "Parameters.SuppliedBaselineStatisticsUri"}, "SuppliedBaselineConstraints": {"Get": "Parameters.SuppliedBaselineConstraintsUri"}, @@ -228,6 +229,7 @@ def sagemaker_session(boto_session, client): "CheckType": "MODEL_QUALITY", "ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"}, "SkipCheck": False, + "FailOnViolation": True, "RegisterNewBaseline": False, "SuppliedBaselineStatistics": {"Get": "Parameters.SuppliedBaselineStatisticsUri"}, "SuppliedBaselineConstraints": {"Get": "Parameters.SuppliedBaselineConstraintsUri"}, @@ -278,6 +280,7 @@ def test_data_quality_check_step( data_quality_check_step = QualityCheckStep( name="DataQualityCheckStep", skip_check=False, + fail_on_violation=False, register_new_baseline=False, quality_check_config=data_quality_check_config, check_job_config=check_job_config, @@ -324,6 +327,7 @@ def test_model_quality_check_step( name="ModelQualityCheckStep", register_new_baseline=False, skip_check=False, + fail_on_violation=True, quality_check_config=model_quality_check_config, check_job_config=check_job_config, model_package_group_name=model_package_group_name, diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index 66a7c2fc43..1f2b80c962 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -162,9 +162,7 @@ @pytest.fixture def client(): """Mock client. - Considerations when appropriate: - * utilize botocore.stub.Stubber * separate runtime client from client """ diff --git a/tests/unit/sagemaker/workflow/test_tuning_step.py b/tests/unit/sagemaker/workflow/test_tuning_step.py index a39512d006..4adccf4b6b 100644 --- a/tests/unit/sagemaker/workflow/test_tuning_step.py +++ b/tests/unit/sagemaker/workflow/test_tuning_step.py @@ -44,9 +44,7 @@ @pytest.fixture def client(): """Mock client. - Considerations when appropriate: - * utilize botocore.stub.Stubber * separate runtime client from client """