Skip to content

fix: fixed implementation of fail_on_violation for transform with monitoring #4442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def transform_with_monitoring(
wait: bool = True,
pipeline_name: str = None,
role: str = None,
fail_on_violation: bool = True,
):
"""Runs a transform job with monitoring job.

Expand All @@ -352,7 +353,6 @@ def transform_with_monitoring(
]): the monitoring configuration used for run model monitoring.
monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`):
the check job (processing job) cluster resource configuration.
transform_step_args (_JobStepArguments): the transform step transform arguments.
data (str): Input data location in S3 for the transform job
data_type (str): What the S3 location defines (default: 'S3Prefix').
Valid values:
Expand Down Expand Up @@ -400,8 +400,6 @@ def transform_with_monitoring(
monitor_before_transform (bool): Whether to run the data quality
or model explainability monitoring type;
a true value of this flag indicates running the check step before the transform job.
fail_on_violation (Union[bool, PipelineVariable]): A opt-out flag to not to fail the
check step when a violation is detected.
supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path
to the supplied statistics object representing the statistics JSON file
which will be used for drift to check (default: None).
Expand All @@ -411,6 +409,8 @@ def transform_with_monitoring(
wait (bool): To determine if needed to wait for the pipeline execution to complete
pipeline_name (str): The name of the Pipeline for the monitoring and transform step
role (str): Execution role
fail_on_violation (Union[bool, PipelineVariable]): An opt-out flag to avoid failing the
check step when a violation is detected.
"""

transformer = self
Expand Down Expand Up @@ -454,6 +454,7 @@ def transform_with_monitoring(
monitor_before_transform=monitor_before_transform,
supplied_baseline_constraints=supplied_baseline_constraints,
supplied_baseline_statistics=supplied_baseline_statistics,
fail_on_violation=fail_on_violation,
)

pipeline_name = (
Expand Down
64 changes: 64 additions & 0 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,3 +709,67 @@ def test_transformer_and_monitoring_job(
assert execution_step["StepStatus"] == "Succeeded"

xgb_model.delete_model()


def test_transformer_and_monitoring_job_to_pass_with_no_failure_in_violation(
    pipeline_session,
    sagemaker_session,
    role,
    pipeline_name,
    check_job_config,
    data_bias_check_config,
):
    """Verify that with fail_on_violation=False the monitoring-and-transform
    pipeline succeeds even though the supplied baseline contains violations.

    Both pipeline steps (check + transform) are expected to finish with
    status ``Succeeded``.
    """
    # Baseline constraints built from a known-bad analysis file, so the
    # data-bias check step will detect violations.
    baseline_constraints_uri = Constraints.from_file_path(
        constraints_file_path=os.path.join(
            DATA_DIR, "pipeline/clarify_check_step/data_bias/bad_cases/analysis.json"
        ),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri

    model_artifact_uri = pipeline_session.upload_data(
        path=os.path.join(DATA_DIR, "xgboost_abalone", "xgb_model.tar.gz"),
        key_prefix="integ-test-data/xgboost/model",
    )

    model = XGBoostModel(
        model_data=model_artifact_uri,
        framework_version="1.3-1",
        role=role,
        sagemaker_session=sagemaker_session,
        entry_point=os.path.join(DATA_DIR, "xgboost_abalone", "inference.py"),
        enable_network_isolation=True,
    )

    model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE)

    batch_output_uri = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform"
    batch_transformer = Transformer(
        model_name=model.name,
        strategy="SingleRecord",
        instance_type="ml.m5.xlarge",
        instance_count=1,
        output_path=batch_output_uri,
        sagemaker_session=pipeline_session,
    )

    batch_input_uri = pipeline_session.upload_data(
        path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"),
        key_prefix="integ-test-data/xgboost_abalone/abalone",
    )

    # fail_on_violation=False is the behavior under test: violations must
    # not fail the check step.
    run = batch_transformer.transform_with_monitoring(
        monitoring_config=data_bias_check_config,
        monitoring_resource_config=check_job_config,
        data=batch_input_uri,
        content_type="text/libsvm",
        supplied_baseline_constraints=baseline_constraints_uri,
        role=role,
        fail_on_violation=False,
    )

    steps = run.list_steps()
    assert len(steps) == 2
    assert all(step["StepStatus"] == "Succeeded" for step in steps)

    model.delete_model()
61 changes: 60 additions & 1 deletion tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@

from tests.integ import test_local_mode
from tests.unit import SAGEMAKER_CONFIG_TRANSFORM_JOB
from sagemaker.model_monitor import DatasetFormat
from sagemaker.workflow.quality_check_step import (
ModelQualityCheckConfig,
)
from sagemaker.workflow.check_job_config import CheckJobConfig

_CHECK_JOB_PREFIX = "CheckJobPrefix"

ROLE = "DummyRole"
REGION = "us-west-2"
Expand All @@ -49,6 +56,16 @@
"base_transform_job_name": JOB_NAME,
}

PROCESS_REQUEST_ARGS = {
"inputs": "processing_inputs",
"output_config": "output_config",
"job_name": "job_name",
"resources": "resource_config",
"stopping_condition": {"MaxRuntimeInSeconds": 3600},
"app_specification": "app_specification",
"experiment_config": {"ExperimentName": "AnExperiment"},
}

MODEL_DESC_PRIMARY_CONTAINER = {"PrimaryContainer": {"Image": IMAGE_URI}}

MODEL_DESC_CONTAINERS_ONLY = {"Containers": [{"Image": IMAGE_URI}]}
Expand All @@ -72,7 +89,7 @@ def mock_create_tar_file():

@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session")
boto_mock = Mock(name="boto_session", region_name=REGION)
session = Mock(
name="sagemaker_session",
boto_session=boto_mock,
Expand Down Expand Up @@ -764,6 +781,48 @@ def test_stop_transform_job(sagemaker_session, transformer):
sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME)


@patch("sagemaker.transformer.Transformer._retrieve_image_uri", return_value=IMAGE_URI)
@patch("sagemaker.workflow.pipeline.Pipeline.upsert", return_value={})
@patch("sagemaker.workflow.pipeline.Pipeline.start", return_value=Mock())
def test_transform_with_monitoring_create_and_starts_pipeline(
    pipeline_start, upsert, image_uri, sagemaker_session, transformer
):
    """transform_with_monitoring should build a pipeline, upsert it once,
    and start exactly one execution (Pipeline.upsert/start are mocked)."""
    resource_config = CheckJobConfig(
        role=ROLE,
        instance_count=1,
        instance_type="ml.m5.xlarge",
        volume_size_in_gb=60,
        max_runtime_in_seconds=1800,
        sagemaker_session=sagemaker_session,
        base_job_name=_CHECK_JOB_PREFIX,
    )

    # Minimal-but-complete model-quality config; attribute values are
    # placeholders since the pipeline never actually executes.
    monitoring_config = ModelQualityCheckConfig(
        baseline_dataset="s3://baseline_dataset_s3_url",
        dataset_format=DatasetFormat.csv(header=True),
        problem_type="BinaryClassification",
        inference_attribute="quality_cfg_attr_value",
        probability_attribute="quality_cfg_attr_value",
        ground_truth_attribute="quality_cfg_attr_value",
        probability_threshold_attribute="quality_cfg_attr_value",
        post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
        output_s3_uri="s3://output_s3_uri",
    )

    transformer.transform_with_monitoring(
        monitoring_config=monitoring_config,
        monitoring_resource_config=resource_config,
        data=DATA,
        content_type="text/libsvm",
        supplied_baseline_constraints="supplied_baseline_constraints",
        role=ROLE,
    )

    upsert.assert_called_once()
    pipeline_start.assert_called_once()


def test_stop_transform_job_no_transform_job(transformer):
with pytest.raises(ValueError) as e:
transformer.stop_transform_job()
Expand Down