From 88ab7eef9a324008587ea612e1452c56feff15a3 Mon Sep 17 00:00:00 2001
From: Keshav Chandak
Date: Wed, 21 Feb 2024 08:24:18 +0000
Subject: [PATCH] fix: fixed implementation of fail_on_violation for transform
 with monitoring

---
 src/sagemaker/transformer.py    |  7 ++--
 tests/integ/test_transformer.py | 64 +++++++++++++++++++++++++++++++++
 tests/unit/test_transformer.py  | 61 ++++++++++++++++++++++++++++++-
 3 files changed, 128 insertions(+), 4 deletions(-)

diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py
index 4ddbbc5451..d52bf52186 100644
--- a/src/sagemaker/transformer.py
+++ b/src/sagemaker/transformer.py
@@ -337,6 +337,7 @@ def transform_with_monitoring(
         wait: bool = True,
         pipeline_name: str = None,
         role: str = None,
+        fail_on_violation: bool = True,
     ):
         """Runs a transform job with monitoring job.
 
@@ -352,7 +353,6 @@
             ]): the monitoring configuration used for run model monitoring.
             monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`):
                 the check job (processing job) cluster resource configuration.
-            transform_step_args (_JobStepArguments): the transform step transform arguments.
             data (str): Input data location in S3 for the transform job
             data_type (str): What the S3 location defines (default: 'S3Prefix').
                 Valid values:
@@ -400,8 +400,6 @@
             monitor_before_transform (bgool): If to run data quality or model explainability
                 monitoring type, a true value of this flag indicates running the check step before
                 the transform job.
-            fail_on_violation (Union[bool, PipelineVariable]): A opt-out flag to not to fail the
-                check step when a violation is detected.
             supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path to the
                 supplied statistics object representing the statistics JSON file which will be used
                 for drift to check (default: None).
@@ -411,6 +409,8 @@
             wait (bool): To determine if needed to wait for the pipeline execution to complete
             pipeline_name (str): The name of the Pipeline for the monitoring and transfrom step
             role (str): Execution role
+            fail_on_violation (Union[bool, PipelineVariable]): An opt-out flag to not fail the
+                check step when a violation is detected.
         """
 
         transformer = self
@@ -454,6 +454,7 @@ def transform_with_monitoring(
             monitor_before_transform=monitor_before_transform,
             supplied_baseline_constraints=supplied_baseline_constraints,
             supplied_baseline_statistics=supplied_baseline_statistics,
+            fail_on_violation=fail_on_violation,
         )
 
         pipeline_name = (
diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py
index c1fa2f15d4..d25f45d4db 100644
--- a/tests/integ/test_transformer.py
+++ b/tests/integ/test_transformer.py
@@ -709,3 +709,67 @@ def test_transformer_and_monitoring_job(
         assert execution_step["StepStatus"] == "Succeeded"
 
     xgb_model.delete_model()
+
+
+def test_transformer_and_monitoring_job_to_pass_with_no_failure_in_violation(
+    pipeline_session,
+    sagemaker_session,
+    role,
+    pipeline_name,
+    check_job_config,
+    data_bias_check_config,
+):
+    xgb_model_data_s3 = pipeline_session.upload_data(
+        path=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "xgb_model.tar.gz"),
+        key_prefix="integ-test-data/xgboost/model",
+    )
+    data_bias_supplied_baseline_constraints = Constraints.from_file_path(
+        constraints_file_path=os.path.join(
+            DATA_DIR, "pipeline/clarify_check_step/data_bias/bad_cases/analysis.json"
+        ),
+        sagemaker_session=sagemaker_session,
+    ).file_s3_uri
+
+    xgb_model = XGBoostModel(
+        model_data=xgb_model_data_s3,
+        framework_version="1.3-1",
+        role=role,
+        sagemaker_session=sagemaker_session,
+        entry_point=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "inference.py"),
+        enable_network_isolation=True,
+    )
+
+    xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE)
+
+    transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform"
+    transformer = Transformer(
+        model_name=xgb_model.name,
+        strategy="SingleRecord",
+        instance_type="ml.m5.xlarge",
+        instance_count=1,
+        output_path=transform_output,
+        sagemaker_session=pipeline_session,
+    )
+
+    transform_input = pipeline_session.upload_data(
+        path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"),
+        key_prefix="integ-test-data/xgboost_abalone/abalone",
+    )
+
+    execution = transformer.transform_with_monitoring(
+        monitoring_config=data_bias_check_config,
+        monitoring_resource_config=check_job_config,
+        data=transform_input,
+        content_type="text/libsvm",
+        supplied_baseline_constraints=data_bias_supplied_baseline_constraints,
+        role=role,
+        fail_on_violation=False,
+    )
+
+    execution_steps = execution.list_steps()
+    assert len(execution_steps) == 2
+
+    for execution_step in execution_steps:
+        assert execution_step["StepStatus"] == "Succeeded"
+
+    xgb_model.delete_model()
diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py
index 8497bc7ea0..e744e461a1 100644
--- a/tests/unit/test_transformer.py
+++ b/tests/unit/test_transformer.py
@@ -23,6 +23,13 @@
 from tests.integ import test_local_mode
 from tests.unit import SAGEMAKER_CONFIG_TRANSFORM_JOB
 
+from sagemaker.model_monitor import DatasetFormat
+from sagemaker.workflow.quality_check_step import (
+    ModelQualityCheckConfig,
+)
+from sagemaker.workflow.check_job_config import CheckJobConfig
+
+_CHECK_JOB_PREFIX = "CheckJobPrefix"
 
 ROLE = "DummyRole"
 REGION = "us-west-2"
@@ -49,6 +56,16 @@
     "base_transform_job_name": JOB_NAME,
 }
 
+PROCESS_REQUEST_ARGS = {
+    "inputs": "processing_inputs",
+    "output_config": "output_config",
+    "job_name": "job_name",
+    "resources": "resource_config",
+    "stopping_condition": {"MaxRuntimeInSeconds": 3600},
+    "app_specification": "app_specification",
+    "experiment_config": {"ExperimentName": "AnExperiment"},
+}
+
 MODEL_DESC_PRIMARY_CONTAINER = {"PrimaryContainer": {"Image": IMAGE_URI}}
 MODEL_DESC_CONTAINERS_ONLY = {"Containers": [{"Image": IMAGE_URI}]}
 
@@ -72,7 +89,7 @@ def mock_create_tar_file():
 
 @pytest.fixture()
 def sagemaker_session():
-    boto_mock = Mock(name="boto_session")
+    boto_mock = Mock(name="boto_session", region_name=REGION)
     session = Mock(
         name="sagemaker_session",
         boto_session=boto_mock,
@@ -764,6 +781,48 @@ def test_stop_transform_job(sagemaker_session, transformer):
 
     sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME)
 
+
+@patch("sagemaker.transformer.Transformer._retrieve_image_uri", return_value=IMAGE_URI)
+@patch("sagemaker.workflow.pipeline.Pipeline.upsert", return_value={})
+@patch("sagemaker.workflow.pipeline.Pipeline.start", return_value=Mock())
+def test_transform_with_monitoring_create_and_starts_pipeline(
+    pipeline_start, upsert, image_uri, sagemaker_session, transformer
+):
+
+    config = CheckJobConfig(
+        role=ROLE,
+        instance_count=1,
+        instance_type="ml.m5.xlarge",
+        volume_size_in_gb=60,
+        max_runtime_in_seconds=1800,
+        sagemaker_session=sagemaker_session,
+        base_job_name=_CHECK_JOB_PREFIX,
+    )
+
+    quality_check_config = ModelQualityCheckConfig(
+        baseline_dataset="s3://baseline_dataset_s3_url",
+        dataset_format=DatasetFormat.csv(header=True),
+        problem_type="BinaryClassification",
+        inference_attribute="quality_cfg_attr_value",
+        probability_attribute="quality_cfg_attr_value",
+        ground_truth_attribute="quality_cfg_attr_value",
+        probability_threshold_attribute="quality_cfg_attr_value",
+        post_analytics_processor_script="s3://my_bucket/data_quality/postprocessor.py",
+        output_s3_uri="s3://output_s3_uri",
+    )
+
+    transformer.transform_with_monitoring(
+        monitoring_config=quality_check_config,
+        monitoring_resource_config=config,
+        data=DATA,
+        content_type="text/libsvm",
+        supplied_baseline_constraints="supplied_baseline_constraints",
+        role=ROLE,
+    )
+
+    upsert.assert_called_once()
+    pipeline_start.assert_called_once()
+
+
 def test_stop_transform_job_no_transform_job(transformer):
     with pytest.raises(ValueError) as e:
         transformer.stop_transform_job()
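
Illustrative usage (not part of the patch): a minimal sketch of how a caller might opt out of failing the monitoring check step by passing the new fail_on_violation flag to transform_with_monitoring. The role ARN, model name, and S3 URIs below are placeholders, and the check configuration loosely mirrors the unit test above.

from sagemaker import Session
from sagemaker.model_monitor import DatasetFormat
from sagemaker.transformer import Transformer
from sagemaker.workflow.check_job_config import CheckJobConfig
from sagemaker.workflow.quality_check_step import ModelQualityCheckConfig

session = Session()
role = "arn:aws:iam::111122223333:role/SageMakerExecutionRole"  # placeholder ARN

# Cluster resources for the monitoring (check) processing job.
check_job_config = CheckJobConfig(
    role=role,
    instance_count=1,
    instance_type="ml.m5.xlarge",
    volume_size_in_gb=60,
    max_runtime_in_seconds=1800,
    sagemaker_session=session,
)

# Model-quality monitoring configuration; dataset and output URIs are placeholders.
quality_check_config = ModelQualityCheckConfig(
    baseline_dataset="s3://my-bucket/monitoring/baseline.csv",
    dataset_format=DatasetFormat.csv(header=True),
    problem_type="BinaryClassification",
    inference_attribute="prediction",
    ground_truth_attribute="label",
    output_s3_uri="s3://my-bucket/monitoring/output",
)

transformer = Transformer(
    model_name="my-existing-model",  # placeholder model name
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://my-bucket/transform/output",
    sagemaker_session=session,
)

# With fail_on_violation=False, a detected violation no longer fails the check
# step, so the monitoring-plus-transform pipeline can still complete successfully
# (the integration test above asserts exactly this behavior).
transformer.transform_with_monitoring(
    monitoring_config=quality_check_config,
    monitoring_resource_config=check_job_config,
    data="s3://my-bucket/transform/input",
    content_type="text/csv",
    role=role,
    fail_on_violation=False,
)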