Skip to content

Commit 24b7c24

Browse files
author
Keshav Chandak
committed
fix: fixed implementation of fail_on_violation for transform with monitoring
1 parent eb49090 commit 24b7c24

File tree

3 files changed

+127
-3
lines changed

3 files changed

+127
-3
lines changed

src/sagemaker/transformer.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ def transform_with_monitoring(
337337
wait: bool = True,
338338
pipeline_name: str = None,
339339
role: str = None,
340+
fail_on_violation: bool = True,
340341
):
341342
"""Runs a transform job with monitoring job.
342343
@@ -352,7 +353,6 @@ def transform_with_monitoring(
352353
]): the monitoring configuration used for run model monitoring.
353354
monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`):
354355
the check job (processing job) cluster resource configuration.
355-
transform_step_args (_JobStepArguments): the transform step transform arguments.
356356
data (str): Input data location in S3 for the transform job
357357
data_type (str): What the S3 location defines (default: 'S3Prefix').
358358
Valid values:
@@ -400,8 +400,6 @@ def transform_with_monitoring(
400400
monitor_before_transform (bool): If to run data quality
401401
or model explainability monitoring type,
402402
a true value of this flag indicates running the check step before the transform job.
403-
fail_on_violation (Union[bool, PipelineVariable]): An opt-out flag to not fail the
404-
check step when a violation is detected.
405403
supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path
406404
to the supplied statistics object representing the statistics JSON file
407405
which will be used for drift to check (default: None).
@@ -411,6 +409,8 @@ def transform_with_monitoring(
411409
wait (bool): Whether to wait for the pipeline execution to complete
412410
pipeline_name (str): The name of the Pipeline for the monitoring and transform step
413411
role (str): Execution role
412+
fail_on_violation (Union[bool, PipelineVariable]): An opt-out flag to not fail the
413+
check step when a violation is detected.
414414
"""
415415

416416
transformer = self
@@ -454,6 +454,7 @@ def transform_with_monitoring(
454454
monitor_before_transform=monitor_before_transform,
455455
supplied_baseline_constraints=supplied_baseline_constraints,
456456
supplied_baseline_statistics=supplied_baseline_statistics,
457+
fail_on_violation=fail_on_violation,
457458
)
458459

459460
pipeline_name = (

tests/integ/test_transformer.py

+64
Original file line numberDiff line numberDiff line change
@@ -709,3 +709,67 @@ def test_transformer_and_monitoring_job(
709709
assert execution_step["StepStatus"] == "Succeeded"
710710

711711
xgb_model.delete_model()
712+
713+
714+
def test_transformer_and_monitoring_job_to_pass_with_no_failure_in_violation(
715+
pipeline_session,
716+
sagemaker_session,
717+
role,
718+
pipeline_name,
719+
check_job_config,
720+
data_bias_check_config,
721+
):
722+
xgb_model_data_s3 = pipeline_session.upload_data(
723+
path=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "xgb_model.tar.gz"),
724+
key_prefix="integ-test-data/xgboost/model",
725+
)
726+
data_bias_supplied_baseline_constraints = Constraints.from_file_path(
727+
constraints_file_path=os.path.join(
728+
DATA_DIR, "pipeline/clarify_check_step/data_bias/bad_cases/analysis.json"
729+
),
730+
sagemaker_session=sagemaker_session,
731+
).file_s3_uri
732+
733+
xgb_model = XGBoostModel(
734+
model_data=xgb_model_data_s3,
735+
framework_version="1.3-1",
736+
role=role,
737+
sagemaker_session=sagemaker_session,
738+
entry_point=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "inference.py"),
739+
enable_network_isolation=True,
740+
)
741+
742+
xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE)
743+
744+
transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform"
745+
transformer = Transformer(
746+
model_name=xgb_model.name,
747+
strategy="SingleRecord",
748+
instance_type="ml.m5.xlarge",
749+
instance_count=1,
750+
output_path=transform_output,
751+
sagemaker_session=pipeline_session,
752+
)
753+
754+
transform_input = pipeline_session.upload_data(
755+
path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"),
756+
key_prefix="integ-test-data/xgboost_abalone/abalone",
757+
)
758+
759+
execution = transformer.transform_with_monitoring(
760+
monitoring_config=data_bias_check_config,
761+
monitoring_resource_config=check_job_config,
762+
data=transform_input,
763+
content_type="text/libsvm",
764+
supplied_baseline_constraints=data_bias_supplied_baseline_constraints,
765+
role=role,
766+
fail_on_violation=False,
767+
)
768+
769+
execution_steps = execution.list_steps()
770+
assert len(execution_steps) == 2
771+
772+
for execution_step in execution_steps:
773+
assert execution_step["StepStatus"] == "Succeeded"
774+
775+
xgb_model.delete_model()

tests/unit/test_transformer.py

+59
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@
2323

2424
from tests.integ import test_local_mode
2525
from tests.unit import SAGEMAKER_CONFIG_TRANSFORM_JOB
26+
from sagemaker.model_monitor import DatasetFormat
27+
from sagemaker.workflow.quality_check_step import (
28+
ModelQualityCheckConfig,
29+
)
30+
from sagemaker.workflow.check_job_config import CheckJobConfig
31+
32+
_CHECK_JOB_PREFIX = "CheckJobPrefix"
2633

2734
ROLE = "DummyRole"
2835
REGION = "us-west-2"
@@ -49,6 +56,16 @@
4956
"base_transform_job_name": JOB_NAME,
5057
}
5158

59+
PROCESS_REQUEST_ARGS = {
60+
"inputs": "processing_inputs",
61+
"output_config": "output_config",
62+
"job_name": "job_name",
63+
"resources": "resource_config",
64+
"stopping_condition": {"MaxRuntimeInSeconds": 3600},
65+
"app_specification": "app_specification",
66+
"experiment_config": {"ExperimentName": "AnExperiment"},
67+
}
68+
5269
MODEL_DESC_PRIMARY_CONTAINER = {"PrimaryContainer": {"Image": IMAGE_URI}}
5370

5471
MODEL_DESC_CONTAINERS_ONLY = {"Containers": [{"Image": IMAGE_URI}]}
@@ -764,6 +781,48 @@ def test_stop_transform_job(sagemaker_session, transformer):
764781
sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME)
765782

766783

784+
@patch("sagemaker.transformer.Transformer._retrieve_image_uri", return_value=IMAGE_URI)
785+
@patch("sagemaker.workflow.pipeline.Pipeline.upsert", return_value={})
786+
@patch("sagemaker.workflow.pipeline.Pipeline.start", return_value=Mock())
787+
def test_transform_with_monitoring_create_and_starts_pipeline(
788+
pipeline_start, upsert, image_uri, sagemaker_session, transformer
789+
):
790+
791+
config = CheckJobConfig(
792+
role=ROLE,
793+
instance_count=1,
794+
instance_type="ml.m5.xlarge",
795+
volume_size_in_gb=60,
796+
max_runtime_in_seconds=1800,
797+
sagemaker_session=sagemaker_session,
798+
base_job_name=_CHECK_JOB_PREFIX,
799+
)
800+
801+
quality_check_config = ModelQualityCheckConfig(
802+
baseline_dataset="s3://baseline_dataset_s3_url",
803+
dataset_format=DatasetFormat.csv(header=True),
804+
problem_type="BinaryClassification",
805+
inference_attribute="quality_cfg_attr_value",
806+
probability_attribute="quality_cfg_attr_value",
807+
ground_truth_attribute="quality_cfg_attr_value",
808+
probability_threshold_attribute="quality_cfg_attr_value",
809+
post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
810+
output_s3_uri="s3://output_s3_uri",
811+
)
812+
813+
transformer.transform_with_monitoring(
814+
monitoring_config=quality_check_config,
815+
monitoring_resource_config=config,
816+
data=DATA,
817+
content_type="text/libsvm",
818+
supplied_baseline_constraints="supplied_baseline_constraints",
819+
role=ROLE,
820+
)
821+
822+
upsert.assert_called_once()
823+
pipeline_start.assert_called_once()
824+
825+
767826
def test_stop_transform_job_no_transform_job(transformer):
768827
with pytest.raises(ValueError) as e:
769828
transformer.stop_transform_job()

0 commit comments

Comments
 (0)