Skip to content

support fail_on_violation flag for check steps #3288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 7, 2022
5 changes: 5 additions & 0 deletions src/sagemaker/workflow/clarify_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(
clarify_check_config: ClarifyCheckConfig,
check_job_config: CheckJobConfig,
skip_check: Union[bool, PipelineVariable] = False,
fail_on_violation: Union[bool, PipelineVariable] = True,
register_new_baseline: Union[bool, PipelineVariable] = False,
model_package_group_name: Union[str, PipelineVariable] = None,
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
Expand All @@ -169,6 +170,8 @@ def __init__(
check_job_config (CheckJobConfig): A CheckJobConfig instance.
skip_check (bool or PipelineVariable): Whether the check
should be skipped (default: False).
fail_on_violation (bool or PipelineVariable): Whether to fail the step
if a violation is detected (default: True).
register_new_baseline (bool or PipelineVariable): Whether
the new baseline should be registered (default: False).
model_package_group_name (str or PipelineVariable): The name of a
Expand Down Expand Up @@ -214,6 +217,7 @@ def __init__(
name, display_name, description, StepTypeEnum.CLARIFY_CHECK, depends_on
)
self.skip_check = skip_check
self.fail_on_violation = fail_on_violation
self.register_new_baseline = register_new_baseline
self.clarify_check_config = clarify_check_config
self.check_job_config = check_job_config
Expand Down Expand Up @@ -286,6 +290,7 @@ def to_request(self) -> RequestType:

request_dict["ModelPackageGroupName"] = self.model_package_group_name
request_dict["SkipCheck"] = self.skip_check
request_dict["FailOnViolation"] = self.fail_on_violation
request_dict["RegisterNewBaseline"] = self.register_new_baseline
request_dict["SuppliedBaselineConstraints"] = self.supplied_baseline_constraints
if isinstance(
Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/workflow/quality_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(
quality_check_config: QualityCheckConfig,
check_job_config: CheckJobConfig,
skip_check: Union[bool, PipelineVariable] = False,
fail_on_violation: Union[bool, PipelineVariable] = True,
register_new_baseline: Union[bool, PipelineVariable] = False,
model_package_group_name: Union[str, PipelineVariable] = None,
supplied_baseline_statistics: Union[str, PipelineVariable] = None,
Expand All @@ -134,6 +135,8 @@ def __init__(
check_job_config (CheckJobConfig): A CheckJobConfig instance.
skip_check (bool or PipelineVariable): Whether the check
should be skipped (default: False).
fail_on_violation (bool or PipelineVariable): Whether to fail the step
if a violation is detected (default: True).
register_new_baseline (bool or PipelineVariable): Whether
the new baseline should be registered (default: False).
model_package_group_name (str or PipelineVariable): The name of a
Expand Down Expand Up @@ -165,6 +168,7 @@ def __init__(
name, display_name, description, StepTypeEnum.QUALITY_CHECK, depends_on
)
self.skip_check = skip_check
self.fail_on_violation = fail_on_violation
self.register_new_baseline = register_new_baseline
self.check_job_config = check_job_config
self.quality_check_config = quality_check_config
Expand Down Expand Up @@ -257,6 +261,7 @@ def to_request(self) -> RequestType:

request_dict["ModelPackageGroupName"] = self.model_package_group_name
request_dict["SkipCheck"] = self.skip_check
request_dict["FailOnViolation"] = self.fail_on_violation
request_dict["RegisterNewBaseline"] = self.register_new_baseline
request_dict["SuppliedBaselineStatistics"] = self.supplied_baseline_statistics
request_dict["SuppliedBaselineConstraints"] = self.supplied_baseline_constraints
Expand Down
20 changes: 15 additions & 5 deletions tests/integ/sagemaker/workflow/test_clarify_check_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,13 @@ def test_one_step_data_bias_pipeline_happycase(
pass


@pytest.mark.parametrize("fail_on_violation", [None, True, False])
def test_one_step_data_bias_pipeline_constraint_violation(
sagemaker_session,
role,
pipeline_name,
check_job_config,
fail_on_violation,
data_bias_check_config,
supplied_baseline_constraints_uri_param,
):
Expand All @@ -234,6 +236,7 @@ def test_one_step_data_bias_pipeline_constraint_violation(
clarify_check_config=data_bias_check_config,
check_job_config=check_job_config,
skip_check=False,
fail_on_violation=fail_on_violation,
register_new_baseline=False,
supplied_baseline_constraints=supplied_baseline_constraints_uri_param,
)
Expand Down Expand Up @@ -276,12 +279,19 @@ def test_one_step_data_bias_pipeline_constraint_violation(
execution_steps = execution.list_steps()

assert len(execution_steps) == 1
failure_reason = execution_steps[0].get("FailureReason", "")
if _CHECK_FAIL_ERROR_MSG not in failure_reason:
logging.error(f"Pipeline execution failed with error: {failure_reason}. Retrying..")
continue
assert execution_steps[0]["StepName"] == "DataBiasCheckStep"
assert execution_steps[0]["StepStatus"] == "Failed"
failure_reason = execution_steps[0].get("FailureReason", "")

if fail_on_violation is None or fail_on_violation:
if _CHECK_FAIL_ERROR_MSG not in failure_reason:
logging.error(
f"Pipeline execution failed with error: {failure_reason}. Retrying.."
)
continue
assert execution_steps[0]["StepStatus"] == "Failed"
else:
assert _CHECK_FAIL_ERROR_MSG not in failure_reason
assert execution_steps[0]["StepStatus"] == "Succeeded"
break
finally:
try:
Expand Down
22 changes: 16 additions & 6 deletions tests/integ/sagemaker/workflow/test_quality_check_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,13 @@ def test_one_step_data_quality_pipeline_happycase(
pass


@pytest.mark.parametrize("fail_on_violation", [None, True, False])
def test_one_step_data_quality_pipeline_constraint_violation(
sagemaker_session,
role,
pipeline_name,
check_job_config,
fail_on_violation,
supplied_baseline_statistics_uri_param,
supplied_baseline_constraints_uri_param,
data_quality_check_config,
Expand All @@ -234,6 +236,7 @@ def test_one_step_data_quality_pipeline_constraint_violation(
data_quality_check_step = QualityCheckStep(
name="DataQualityCheckStep",
skip_check=False,
fail_on_violation=fail_on_violation,
register_new_baseline=False,
quality_check_config=data_quality_check_config,
check_job_config=check_job_config,
Expand Down Expand Up @@ -274,14 +277,21 @@ def test_one_step_data_quality_pipeline_constraint_violation(
except WaiterError:
pass
execution_steps = execution.list_steps()

assert len(execution_steps) == 1
failure_reason = execution_steps[0].get("FailureReason", "")
if _CHECK_FAIL_ERROR_MSG not in failure_reason:
logging.error(f"Pipeline execution failed with error: {failure_reason}. Retrying..")
continue
assert execution_steps[0]["StepName"] == "DataQualityCheckStep"
assert execution_steps[0]["StepStatus"] == "Failed"

failure_reason = execution_steps[0].get("FailureReason", "")
if fail_on_violation is None or fail_on_violation:
if _CHECK_FAIL_ERROR_MSG not in failure_reason:
logging.error(
f"Pipeline execution failed with error: {failure_reason}. Retrying.."
)
continue
assert execution_steps[0]["StepStatus"] == "Failed"
else:
# fail on violation == false
assert _CHECK_FAIL_ERROR_MSG not in failure_reason
assert execution_steps[0]["StepStatus"] == "Succeeded"
break
finally:
try:
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/sagemaker/workflow/test_clarify_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def sagemaker_session(boto_session, client):
"CheckType": "DATA_BIAS",
"ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"},
"SkipCheck": False,
"FailOnViolation": False,
"RegisterNewBaseline": False,
"SuppliedBaselineConstraints": "supplied_baseline_constraints",
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
Expand Down Expand Up @@ -213,6 +214,7 @@ def sagemaker_session(boto_session, client):
"CheckType": "MODEL_BIAS",
"ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"},
"SkipCheck": False,
"FailOnViolation": True,
"RegisterNewBaseline": False,
"SuppliedBaselineConstraints": "supplied_baseline_constraints",
"ModelName": "model_name",
Expand Down Expand Up @@ -277,6 +279,7 @@ def sagemaker_session(boto_session, client):
"CheckType": "MODEL_EXPLAINABILITY",
"ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"},
"SkipCheck": False,
"FailOnViolation": False,
"RegisterNewBaseline": False,
"SuppliedBaselineConstraints": "supplied_baseline_constraints",
"ModelName": "model_name",
Expand Down Expand Up @@ -365,6 +368,7 @@ def test_data_bias_check_step(
clarify_check_config=data_bias_check_config,
check_job_config=check_job_config,
skip_check=False,
fail_on_violation=False,
register_new_baseline=False,
model_package_group_name=model_package_group_name,
supplied_baseline_constraints="supplied_baseline_constraints",
Expand Down Expand Up @@ -406,6 +410,7 @@ def test_model_bias_check_step(
clarify_check_config=model_bias_check_config,
check_job_config=check_job_config,
skip_check=False,
fail_on_violation=True,
register_new_baseline=False,
model_package_group_name=model_package_group_name,
supplied_baseline_constraints="supplied_baseline_constraints",
Expand Down Expand Up @@ -444,6 +449,7 @@ def test_model_explainability_check_step(
clarify_check_config=model_explainability_check_config,
check_job_config=check_job_config,
skip_check=False,
fail_on_violation=False,
register_new_baseline=False,
model_package_group_name=model_package_group_name,
supplied_baseline_constraints="supplied_baseline_constraints",
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/sagemaker/workflow/test_processing_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,7 @@
@pytest.fixture
def client():
"""Mock client.

Considerations when appropriate:

* utilize botocore.stub.Stubber
* separate runtime client from client
"""
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/sagemaker/workflow/test_quality_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def sagemaker_session(boto_session, client):
"CheckType": "DATA_QUALITY",
"ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"},
"SkipCheck": False,
"FailOnViolation": False,
"RegisterNewBaseline": False,
"SuppliedBaselineStatistics": {"Get": "Parameters.SuppliedBaselineStatisticsUri"},
"SuppliedBaselineConstraints": {"Get": "Parameters.SuppliedBaselineConstraintsUri"},
Expand Down Expand Up @@ -228,6 +229,7 @@ def sagemaker_session(boto_session, client):
"CheckType": "MODEL_QUALITY",
"ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"},
"SkipCheck": False,
"FailOnViolation": True,
"RegisterNewBaseline": False,
"SuppliedBaselineStatistics": {"Get": "Parameters.SuppliedBaselineStatisticsUri"},
"SuppliedBaselineConstraints": {"Get": "Parameters.SuppliedBaselineConstraintsUri"},
Expand Down Expand Up @@ -278,6 +280,7 @@ def test_data_quality_check_step(
data_quality_check_step = QualityCheckStep(
name="DataQualityCheckStep",
skip_check=False,
fail_on_violation=False,
register_new_baseline=False,
quality_check_config=data_quality_check_config,
check_job_config=check_job_config,
Expand Down Expand Up @@ -324,6 +327,7 @@ def test_model_quality_check_step(
name="ModelQualityCheckStep",
register_new_baseline=False,
skip_check=False,
fail_on_violation=True,
quality_check_config=model_quality_check_config,
check_job_config=check_job_config,
model_package_group_name=model_package_group_name,
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/sagemaker/workflow/test_training_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@
@pytest.fixture
def client():
"""Mock client.

Considerations when appropriate:

* utilize botocore.stub.Stubber
* separate runtime client from client
"""
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/sagemaker/workflow/test_tuning_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@
@pytest.fixture
def client():
"""Mock client.

Considerations when appropriate:

* utilize botocore.stub.Stubber
* separate runtime client from client
"""
Expand Down