Skip to content

Commit 4a6912f

Browse files
authored
Merge branch 'master' into fix/estimator-hyperparameters
2 parents b9d0033 + 93af78b commit 4a6912f

9 files changed

+51
-17
lines changed

src/sagemaker/workflow/clarify_check_step.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def __init__(
153153
clarify_check_config: ClarifyCheckConfig,
154154
check_job_config: CheckJobConfig,
155155
skip_check: Union[bool, PipelineVariable] = False,
156+
fail_on_violation: Union[bool, PipelineVariable] = True,
156157
register_new_baseline: Union[bool, PipelineVariable] = False,
157158
model_package_group_name: Union[str, PipelineVariable] = None,
158159
supplied_baseline_constraints: Union[str, PipelineVariable] = None,
@@ -169,6 +170,8 @@ def __init__(
169170
check_job_config (CheckJobConfig): A CheckJobConfig instance.
170171
skip_check (bool or PipelineVariable): Whether the check
171172
should be skipped (default: False).
173+
fail_on_violation (bool or PipelineVariable): Whether to fail the step
174+
if violation detected (default: True).
172175
register_new_baseline (bool or PipelineVariable): Whether
173176
the new baseline should be registered (default: False).
174177
model_package_group_name (str or PipelineVariable): The name of a
@@ -214,6 +217,7 @@ def __init__(
214217
name, display_name, description, StepTypeEnum.CLARIFY_CHECK, depends_on
215218
)
216219
self.skip_check = skip_check
220+
self.fail_on_violation = fail_on_violation
217221
self.register_new_baseline = register_new_baseline
218222
self.clarify_check_config = clarify_check_config
219223
self.check_job_config = check_job_config
@@ -286,6 +290,7 @@ def to_request(self) -> RequestType:
286290

287291
request_dict["ModelPackageGroupName"] = self.model_package_group_name
288292
request_dict["SkipCheck"] = self.skip_check
293+
request_dict["FailOnViolation"] = self.fail_on_violation
289294
request_dict["RegisterNewBaseline"] = self.register_new_baseline
290295
request_dict["SuppliedBaselineConstraints"] = self.supplied_baseline_constraints
291296
if isinstance(

src/sagemaker/workflow/quality_check_step.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def __init__(
117117
quality_check_config: QualityCheckConfig,
118118
check_job_config: CheckJobConfig,
119119
skip_check: Union[bool, PipelineVariable] = False,
120+
fail_on_violation: Union[bool, PipelineVariable] = True,
120121
register_new_baseline: Union[bool, PipelineVariable] = False,
121122
model_package_group_name: Union[str, PipelineVariable] = None,
122123
supplied_baseline_statistics: Union[str, PipelineVariable] = None,
@@ -134,6 +135,8 @@ def __init__(
134135
check_job_config (CheckJobConfig): A CheckJobConfig instance.
135136
skip_check (bool or PipelineVariable): Whether the check
136137
should be skipped (default: False).
138+
fail_on_violation (bool or PipelineVariable): Whether to fail the step
139+
if violation detected (default: True).
137140
register_new_baseline (bool or PipelineVariable): Whether
138141
the new baseline should be registered (default: False).
139142
model_package_group_name (str or PipelineVariable): The name of a
@@ -165,6 +168,7 @@ def __init__(
165168
name, display_name, description, StepTypeEnum.QUALITY_CHECK, depends_on
166169
)
167170
self.skip_check = skip_check
171+
self.fail_on_violation = fail_on_violation
168172
self.register_new_baseline = register_new_baseline
169173
self.check_job_config = check_job_config
170174
self.quality_check_config = quality_check_config
@@ -257,6 +261,7 @@ def to_request(self) -> RequestType:
257261

258262
request_dict["ModelPackageGroupName"] = self.model_package_group_name
259263
request_dict["SkipCheck"] = self.skip_check
264+
request_dict["FailOnViolation"] = self.fail_on_violation
260265
request_dict["RegisterNewBaseline"] = self.register_new_baseline
261266
request_dict["SuppliedBaselineStatistics"] = self.supplied_baseline_statistics
262267
request_dict["SuppliedBaselineConstraints"] = self.supplied_baseline_constraints

tests/integ/sagemaker/workflow/test_clarify_check_steps.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,13 @@ def test_one_step_data_bias_pipeline_happycase(
215215
pass
216216

217217

218+
@pytest.mark.parametrize("fail_on_violation", [None, True, False])
218219
def test_one_step_data_bias_pipeline_constraint_violation(
219220
sagemaker_session,
220221
role,
221222
pipeline_name,
222223
check_job_config,
224+
fail_on_violation,
223225
data_bias_check_config,
224226
supplied_baseline_constraints_uri_param,
225227
):
@@ -234,6 +236,7 @@ def test_one_step_data_bias_pipeline_constraint_violation(
234236
clarify_check_config=data_bias_check_config,
235237
check_job_config=check_job_config,
236238
skip_check=False,
239+
fail_on_violation=fail_on_violation,
237240
register_new_baseline=False,
238241
supplied_baseline_constraints=supplied_baseline_constraints_uri_param,
239242
)
@@ -276,12 +279,19 @@ def test_one_step_data_bias_pipeline_constraint_violation(
276279
execution_steps = execution.list_steps()
277280

278281
assert len(execution_steps) == 1
279-
failure_reason = execution_steps[0].get("FailureReason", "")
280-
if _CHECK_FAIL_ERROR_MSG not in failure_reason:
281-
logging.error(f"Pipeline execution failed with error: {failure_reason}. Retrying..")
282-
continue
283282
assert execution_steps[0]["StepName"] == "DataBiasCheckStep"
284-
assert execution_steps[0]["StepStatus"] == "Failed"
283+
failure_reason = execution_steps[0].get("FailureReason", "")
284+
285+
if fail_on_violation is None or fail_on_violation:
286+
if _CHECK_FAIL_ERROR_MSG not in failure_reason:
287+
logging.error(
288+
f"Pipeline execution failed with error: {failure_reason}. Retrying.."
289+
)
290+
continue
291+
assert execution_steps[0]["StepStatus"] == "Failed"
292+
else:
293+
assert _CHECK_FAIL_ERROR_MSG not in failure_reason
294+
assert execution_steps[0]["StepStatus"] == "Succeeded"
285295
break
286296
finally:
287297
try:

tests/integ/sagemaker/workflow/test_quality_check_steps.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,13 @@ def test_one_step_data_quality_pipeline_happycase(
215215
pass
216216

217217

218+
@pytest.mark.parametrize("fail_on_violation", [None, True, False])
218219
def test_one_step_data_quality_pipeline_constraint_violation(
219220
sagemaker_session,
220221
role,
221222
pipeline_name,
222223
check_job_config,
224+
fail_on_violation,
223225
supplied_baseline_statistics_uri_param,
224226
supplied_baseline_constraints_uri_param,
225227
data_quality_check_config,
@@ -234,6 +236,7 @@ def test_one_step_data_quality_pipeline_constraint_violation(
234236
data_quality_check_step = QualityCheckStep(
235237
name="DataQualityCheckStep",
236238
skip_check=False,
239+
fail_on_violation=fail_on_violation,
237240
register_new_baseline=False,
238241
quality_check_config=data_quality_check_config,
239242
check_job_config=check_job_config,
@@ -274,14 +277,21 @@ def test_one_step_data_quality_pipeline_constraint_violation(
274277
except WaiterError:
275278
pass
276279
execution_steps = execution.list_steps()
277-
278280
assert len(execution_steps) == 1
279-
failure_reason = execution_steps[0].get("FailureReason", "")
280-
if _CHECK_FAIL_ERROR_MSG not in failure_reason:
281-
logging.error(f"Pipeline execution failed with error: {failure_reason}. Retrying..")
282-
continue
283281
assert execution_steps[0]["StepName"] == "DataQualityCheckStep"
284-
assert execution_steps[0]["StepStatus"] == "Failed"
282+
283+
failure_reason = execution_steps[0].get("FailureReason", "")
284+
if fail_on_violation is None or fail_on_violation:
285+
if _CHECK_FAIL_ERROR_MSG not in failure_reason:
286+
logging.error(
287+
f"Pipeline execution failed with error: {failure_reason}. Retrying.."
288+
)
289+
continue
290+
assert execution_steps[0]["StepStatus"] == "Failed"
291+
else:
292+
# fail on violation == false
293+
assert _CHECK_FAIL_ERROR_MSG not in failure_reason
294+
assert execution_steps[0]["StepStatus"] == "Succeeded"
285295
break
286296
finally:
287297
try:

tests/unit/sagemaker/workflow/test_clarify_check_step.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def sagemaker_session(boto_session, client):
149149
"CheckType": "DATA_BIAS",
150150
"ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"},
151151
"SkipCheck": False,
152+
"FailOnViolation": False,
152153
"RegisterNewBaseline": False,
153154
"SuppliedBaselineConstraints": "supplied_baseline_constraints",
154155
"CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
@@ -213,6 +214,7 @@ def sagemaker_session(boto_session, client):
213214
"CheckType": "MODEL_BIAS",
214215
"ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"},
215216
"SkipCheck": False,
217+
"FailOnViolation": True,
216218
"RegisterNewBaseline": False,
217219
"SuppliedBaselineConstraints": "supplied_baseline_constraints",
218220
"ModelName": "model_name",
@@ -277,6 +279,7 @@ def sagemaker_session(boto_session, client):
277279
"CheckType": "MODEL_EXPLAINABILITY",
278280
"ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"},
279281
"SkipCheck": False,
282+
"FailOnViolation": False,
280283
"RegisterNewBaseline": False,
281284
"SuppliedBaselineConstraints": "supplied_baseline_constraints",
282285
"ModelName": "model_name",
@@ -365,6 +368,7 @@ def test_data_bias_check_step(
365368
clarify_check_config=data_bias_check_config,
366369
check_job_config=check_job_config,
367370
skip_check=False,
371+
fail_on_violation=False,
368372
register_new_baseline=False,
369373
model_package_group_name=model_package_group_name,
370374
supplied_baseline_constraints="supplied_baseline_constraints",
@@ -406,6 +410,7 @@ def test_model_bias_check_step(
406410
clarify_check_config=model_bias_check_config,
407411
check_job_config=check_job_config,
408412
skip_check=False,
413+
fail_on_violation=True,
409414
register_new_baseline=False,
410415
model_package_group_name=model_package_group_name,
411416
supplied_baseline_constraints="supplied_baseline_constraints",
@@ -444,6 +449,7 @@ def test_model_explainability_check_step(
444449
clarify_check_config=model_explainability_check_config,
445450
check_job_config=check_job_config,
446451
skip_check=False,
452+
fail_on_violation=False,
447453
register_new_baseline=False,
448454
model_package_group_name=model_package_group_name,
449455
supplied_baseline_constraints="supplied_baseline_constraints",

tests/unit/sagemaker/workflow/test_processing_step.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,7 @@
189189
@pytest.fixture
190190
def client():
191191
"""Mock client.
192-
193192
Considerations when appropriate:
194-
195193
* utilize botocore.stub.Stubber
196194
* separate runtime client from client
197195
"""

tests/unit/sagemaker/workflow/test_quality_check_step.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def sagemaker_session(boto_session, client):
154154
"CheckType": "DATA_QUALITY",
155155
"ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"},
156156
"SkipCheck": False,
157+
"FailOnViolation": False,
157158
"RegisterNewBaseline": False,
158159
"SuppliedBaselineStatistics": {"Get": "Parameters.SuppliedBaselineStatisticsUri"},
159160
"SuppliedBaselineConstraints": {"Get": "Parameters.SuppliedBaselineConstraintsUri"},
@@ -228,6 +229,7 @@ def sagemaker_session(boto_session, client):
228229
"CheckType": "MODEL_QUALITY",
229230
"ModelPackageGroupName": {"Get": "Parameters.MyModelPackageGroup"},
230231
"SkipCheck": False,
232+
"FailOnViolation": True,
231233
"RegisterNewBaseline": False,
232234
"SuppliedBaselineStatistics": {"Get": "Parameters.SuppliedBaselineStatisticsUri"},
233235
"SuppliedBaselineConstraints": {"Get": "Parameters.SuppliedBaselineConstraintsUri"},
@@ -278,6 +280,7 @@ def test_data_quality_check_step(
278280
data_quality_check_step = QualityCheckStep(
279281
name="DataQualityCheckStep",
280282
skip_check=False,
283+
fail_on_violation=False,
281284
register_new_baseline=False,
282285
quality_check_config=data_quality_check_config,
283286
check_job_config=check_job_config,
@@ -324,6 +327,7 @@ def test_model_quality_check_step(
324327
name="ModelQualityCheckStep",
325328
register_new_baseline=False,
326329
skip_check=False,
330+
fail_on_violation=True,
327331
quality_check_config=model_quality_check_config,
328332
check_job_config=check_job_config,
329333
model_package_group_name=model_package_group_name,

tests/unit/sagemaker/workflow/test_training_step.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,7 @@
162162
@pytest.fixture
163163
def client():
164164
"""Mock client.
165-
166165
Considerations when appropriate:
167-
168166
* utilize botocore.stub.Stubber
169167
* separate runtime client from client
170168
"""

tests/unit/sagemaker/workflow/test_tuning_step.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@
4444
@pytest.fixture
4545
def client():
4646
"""Mock client.
47-
4847
Considerations when appropriate:
49-
5048
* utilize botocore.stub.Stubber
5149
* separate runtime client from client
5250
"""

0 commit comments

Comments
 (0)