diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py
index 46481e9fbe..e865b4815f 100644
--- a/src/sagemaker/model_monitor/model_monitoring.py
+++ b/src/sagemaker/model_monitor/model_monitoring.py
@@ -433,6 +433,7 @@ def update_monitoring_schedule(
         network_config=None,
         role=None,
         image_uri=None,
+        batch_transform_input=None,
     ):
         """Updates the existing monitoring schedule.
 
@@ -475,14 +476,29 @@ def update_monitoring_schedule(
             role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
             image_uri (str): The uri of the image to use for the jobs started by the Monitor.
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
+                run the monitoring schedule on the batch transform (default: None)
 
         """
         monitoring_inputs = None
+
+        if (batch_transform_input is not None) and (endpoint_input is not None):
+            message = (
+                "Cannot update both batch_transform_input and endpoint_input to update an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide at most one of the above inputs."
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         if endpoint_input is not None:
             monitoring_inputs = [
                 self._normalize_endpoint_input(endpoint_input=endpoint_input)._to_request_dict()
             ]
 
+        elif batch_transform_input is not None:
+            monitoring_inputs = [batch_transform_input._to_request_dict()]
+
         monitoring_output_config = None
         if output is not None:
             normalized_monitoring_output = self._normalize_monitoring_output_fields(output=output)
@@ -1895,6 +1911,7 @@ def update_monitoring_schedule(
         network_config=None,
         enable_cloudwatch_metrics=None,
         role=None,
+        batch_transform_input=None,
     ):
         """Updates the existing monitoring schedule.
 
@@ -1936,8 +1953,20 @@ def update_monitoring_schedule(
             enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
                 the baselining or monitoring jobs.
             role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
+                run the monitoring schedule on the batch transform (default: None)
 
         """
+
+        if (batch_transform_input is not None) and (endpoint_input is not None):
+            message = (
+                "Cannot update both batch_transform_input and endpoint_input to update an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide at most one of the above inputs."
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         # check if this schedule is in v2 format and update as per v2 format if it is
         if self.job_definition_name is not None:
             self._update_data_quality_monitoring_schedule(
@@ -1958,6 +1987,7 @@
                 network_config=network_config,
                 enable_cloudwatch_metrics=enable_cloudwatch_metrics,
                 role=role,
+                batch_transform_input=batch_transform_input,
             )
             return
 
@@ -1965,6 +1995,9 @@
         if endpoint_input is not None:
             monitoring_inputs = [self._normalize_endpoint_input(endpoint_input)._to_request_dict()]
 
+        elif batch_transform_input is not None:
+            monitoring_inputs = [batch_transform_input._to_request_dict()]
+
         record_preprocessor_script_s3_uri = None
         if record_preprocessor_script is not None:
             record_preprocessor_script_s3_uri = self._s3_uri_from_local_path(
@@ -3022,6 +3055,15 @@
             self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
             return
 
+        if (batch_transform_input is not None) and (endpoint_input is not None):
+            message = (
+                "Cannot update both batch_transform_input and endpoint_input to update an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide at most one of the above inputs."
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         # Need to update schedule with a new job definition
         job_desc = self.sagemaker_session.sagemaker_client.describe_model_quality_job_definition(
             JobDefinitionName=self.job_definition_name
diff --git a/tests/integ/test_model_monitor.py b/tests/integ/test_model_monitor.py
index 0b23efdc4c..f0db17e63c 100644
--- a/tests/integ/test_model_monitor.py
+++ b/tests/integ/test_model_monitor.py
@@ -281,7 +281,6 @@ def updated_output_kms_key(sagemaker_session):
     tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS,
     reason="ModelMonitoring is not yet supported in this region.",
 )
-@pytest.mark.release
 def test_default_monitoring_batch_transform_schedule_name(
     sagemaker_session, output_kms_key, volume_kms_key
 ):
@@ -359,7 +358,6 @@
     tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS,
     reason="ModelMonitoring is not yet supported in this region.",
 )
-@pytest.mark.release
 def test_default_monitor_suggest_baseline_and_create_monitoring_schedule_with_customizations(
     sagemaker_session, output_kms_key, volume_kms_key, predictor
 ):
@@ -1852,3 +1850,195 @@ def _verify_default_monitoring_schedule_with_batch_transform(
         )
     else:
         assert network_config is None
+
+
+def test_default_update_monitoring_batch_transform(
+    sagemaker_session, output_kms_key, volume_kms_key
+):
+    my_default_monitor = DefaultModelMonitor(
+        role=ROLE,
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE,
+        volume_size_in_gb=VOLUME_SIZE_IN_GB,
+        volume_kms_key=volume_kms_key,
+        output_kms_key=output_kms_key,
+        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
+        sagemaker_session=sagemaker_session,
+        env=ENVIRONMENT,
+        tags=TAGS,
+        network_config=NETWORK_CONFIG,
+    )
+
+    output_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "integ-test-monitoring-output-bucket",
+        str(uuid.uuid4()),
+    )
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    statistics = Statistics.from_file_path(
+        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    constraints = Constraints.from_file_path(
+        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    my_default_monitor.create_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+        output_s3_uri=output_s3_uri,
+        statistics=statistics,
+        constraints=constraints,
+        schedule_cron_expression=HOURLY_CRON_EXPRESSION,
+        enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_default_monitor)
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-tensorflow-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    my_default_monitor.update_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_default_monitor)
+
+    schedule_description = my_default_monitor.describe_schedule()
+
+    _verify_default_monitoring_schedule_with_batch_transform(
+        sagemaker_session=sagemaker_session,
+        schedule_description=schedule_description,
+        cron_expression=HOURLY_CRON_EXPRESSION,
+        statistics=statistics,
+        constraints=constraints,
+        output_kms_key=output_kms_key,
+        volume_kms_key=volume_kms_key,
+        network_config=NETWORK_CONFIG,
+    )
+
+    my_default_monitor.stop_monitoring_schedule()
+    my_default_monitor.delete_monitoring_schedule()
+
+
+def test_byoc_monitoring_schedule_name_update_batch(
+    sagemaker_session, output_kms_key, volume_kms_key
+):
+    byoc_env = ENVIRONMENT.copy()
+    byoc_env["dataset_format"] = json.dumps(DatasetFormat.csv(header=False))
+    byoc_env["dataset_source"] = "/opt/ml/processing/input/baseline_dataset_input"
+    byoc_env["output_path"] = os.path.join("/opt/ml/processing/output")
+    byoc_env["publish_cloudwatch_metrics"] = "Disabled"
+
+    my_byoc_monitor = ModelMonitor(
+        role=ROLE,
+        image_uri=DefaultModelMonitor._get_default_image_uri(
+            sagemaker_session.boto_session.region_name
+        ),
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE,
+        volume_size_in_gb=VOLUME_SIZE_IN_GB,
+        volume_kms_key=volume_kms_key,
+        output_kms_key=output_kms_key,
+        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
+        sagemaker_session=sagemaker_session,
+        env=byoc_env,
+        tags=TAGS,
+        network_config=NETWORK_CONFIG,
+    )
+
+    output_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "integ-test-monitoring-output-bucket",
+        str(uuid.uuid4()),
+    )
+
+    statistics = Statistics.from_file_path(
+        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    constraints = Constraints.from_file_path(
+        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    my_byoc_monitor.create_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+        output=MonitoringOutput(source="/opt/ml/processing/output", destination=output_s3_uri),
+        statistics=statistics,
+        constraints=constraints,
+        schedule_cron_expression=HOURLY_CRON_EXPRESSION,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-tensorflow-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    my_byoc_monitor.update_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)
+
+    schedule_description = my_byoc_monitor.describe_schedule()
+
+    assert (
+        data_captured_destination_s3_uri
+        == schedule_description["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
+            "MonitoringInputs"
+        ][0]["BatchTransformInput"]["DataCapturedDestinationS3Uri"]
+    )
+
+    my_byoc_monitor.stop_monitoring_schedule()
+    my_byoc_monitor.delete_monitoring_schedule()
diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py
index 77eb8ae506..0a00a1b7cc 100644
--- a/tests/unit/sagemaker/monitor/test_model_monitoring.py
+++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py
@@ -1836,3 +1836,47 @@ def test_model_monitor_with_arguments(
         role_arn=ROLE,
         tags=TAGS,
     )
+
+
+def test_update_model_monitor_error_with_endpoint_and_batch(
+    model_monitor_arguments,
+    data_quality_monitor,
+):
+    endpoint_input = EndpointInput(
+        endpoint_name=ENDPOINT_NAME,
+        destination=ENDPOINT_INPUT_LOCAL_PATH,
+        start_time_offset=START_TIME_OFFSET,
+        end_time_offset=END_TIME_OFFSET,
+        features_attribute=FEATURES_ATTRIBUTE,
+        inference_attribute=INFERENCE_ATTRIBUTE,
+        probability_attribute=PROBABILITY_ATTRIBUTE,
+        probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE,
+    )
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI,
+        destination=SCHEDULE_DESTINATION,
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    with pytest.raises(ValueError) as error:
+        model_monitor_arguments.update_monitoring_schedule(
+            schedule_cron_expression=CRON_HOURLY,
+            endpoint_input=endpoint_input,
+            arguments=ARGUMENTS,
+            output=MonitoringOutput(source="/opt/ml/processing/output", destination=OUTPUT_S3_URI),
+            batch_transform_input=batch_transform_input,
+        )
+    assert "Cannot update both batch_transform_input and endpoint_input to update an" in str(
+        error.value
+    )
+
+    with pytest.raises(ValueError) as error:
+        data_quality_monitor.update_monitoring_schedule(
+            schedule_cron_expression=CRON_HOURLY,
+            endpoint_input=endpoint_input,
+            batch_transform_input=batch_transform_input,
+        )
+    assert "Cannot update both batch_transform_input and endpoint_input to update an" in str(
+        error.value
+    )
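
For context beyond the diff itself, here is a minimal usage sketch of the new batch_transform_input parameter on update_monitoring_schedule. The schedule name and S3 URI below are hypothetical placeholders; DefaultModelMonitor.attach, BatchTransformInput, and MonitoringDatasetFormat are existing SDK APIs, as exercised by the tests above.

    from sagemaker.model_monitor import BatchTransformInput, DefaultModelMonitor
    from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat

    # Attach to an already-created data-quality schedule (name is hypothetical).
    monitor = DefaultModelMonitor.attach(monitor_schedule_name="my-monitoring-schedule")

    # Describe where a batch transform job wrote its captured data (URI is hypothetical).
    batch_transform_input = BatchTransformInput(
        data_captured_destination_s3_uri="s3://my-bucket/batch-transform-capture",
        destination="/opt/ml/processing/input",
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )

    # Re-point the schedule at the batch transform output. Passing endpoint_input
    # together with batch_transform_input now raises ValueError; supply at most one.
    monitor.update_monitoring_schedule(batch_transform_input=batch_transform_input)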