diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 2f8266a43a..61b7d8b440 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -233,6 +233,7 @@ def create_monitoring_schedule( monitor_schedule_name=None, schedule_cron_expression=None, batch_transform_input=None, + arguments=None, ): """Creates a monitoring schedule to monitor an Amazon SageMaker Endpoint. @@ -262,6 +263,7 @@ def create_monitoring_schedule( batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run the monitoring schedule on the batch transform (default: None) + arguments ([str]): String arguments to pass to the processing job (default: None). """ if self.monitoring_schedule_name is not None: @@ -326,6 +328,9 @@ def create_monitoring_schedule( if self.network_config is not None: network_config_dict = self.network_config._to_request_dict() + if arguments is not None: + self.arguments = arguments + self.sagemaker_session.create_monitoring_schedule( monitoring_schedule_name=self.monitoring_schedule_name, schedule_expression=schedule_cron_expression, diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index 0766806684..d208907998 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -29,6 +29,7 @@ BatchTransformInput, ModelQualityMonitor, Statistics, + MonitoringOutput, ) from sagemaker.model_monitor.monitoring_alert import ( MonitoringAlertSummary, @@ -132,6 +133,7 @@ POSTPROCESSOR_URI = "s3://my_bucket/postprocessor.py" DATA_CAPTURED_S3_URI = "s3://my-bucket/batch-fraud-detection/on-schedule-monitoring/in/" DATASET_FORMAT = MonitoringDatasetFormat.csv(header=False) +ARGUMENTS = ["test-argument"] JOB_OUTPUT_CONFIG = { "MonitoringOutputs": [ { @@ -372,6 +374,20 @@ }, "GroundTruthS3Input": {"S3Uri": 
NEW_GROUND_TRUTH_S3_URI}, } +MODEL_QUALITY_MONITOR_JOB_INPUT = { + "EndpointInput": { + "EndpointName": ENDPOINT_NAME, + "LocalPath": ENDPOINT_INPUT_LOCAL_PATH, + "S3InputMode": S3_INPUT_MODE, + "S3DataDistributionType": S3_DATA_DISTRIBUTION_TYPE, + "StartTimeOffset": START_TIME_OFFSET, + "EndTimeOffset": END_TIME_OFFSET, + "FeaturesAttribute": FEATURES_ATTRIBUTE, + "InferenceAttribute": INFERENCE_ATTRIBUTE, + "ProbabilityAttribute": PROBABILITY_ATTRIBUTE, + "ProbabilityThresholdAttribute": PROBABILITY_THRESHOLD_ATTRIBUTE, + }, +} NEW_MODEL_QUALITY_JOB_DEFINITION = { "ModelQualityAppSpecification": NEW_MODEL_QUALITY_APP_SPECIFICATION, "ModelQualityJobInput": NEW_MODEL_QUALITY_JOB_INPUT, @@ -483,6 +499,25 @@ def data_quality_monitor(sagemaker_session): ) +@pytest.fixture() +def model_monitor_arguments(sagemaker_session): + return ModelMonitor( + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + volume_size_in_gb=VOLUME_SIZE_IN_GB, + volume_kms_key=VOLUME_KMS_KEY, + output_kms_key=OUTPUT_KMS_KEY, + max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS, + base_job_name=BASE_JOB_NAME, + sagemaker_session=sagemaker_session, + env=ENVIRONMENT, + tags=TAGS, + network_config=NETWORK_CONFIG, + image_uri=DefaultModelMonitor._get_default_image_uri(REGION), + ) + + def test_default_model_monitor_suggest_baseline(sagemaker_session): my_default_monitor = DefaultModelMonitor( role=ROLE, @@ -1691,3 +1726,54 @@ def test_batch_transform_and_endpoint_input_simultaneous_failure( ) except Exception as e: assert "Need to have either batch_transform_input or endpoint_input" in str(e) + + +def test_model_monitor_with_arguments( + model_monitor_arguments, + sagemaker_session, + constraints=None, + statistics=None, + endpoint_input=EndpointInput( + endpoint_name=ENDPOINT_NAME, + destination=ENDPOINT_INPUT_LOCAL_PATH, + start_time_offset=START_TIME_OFFSET, + end_time_offset=END_TIME_OFFSET, + features_attribute=FEATURES_ATTRIBUTE, + inference_attribute=INFERENCE_ATTRIBUTE, 
+ probability_attribute=PROBABILITY_ATTRIBUTE, + probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE, + ), +): + # endpoint input: the custom arguments must be forwarded to the session call + model_monitor_arguments.create_monitoring_schedule( + constraints=constraints, + statistics=statistics, + monitor_schedule_name=SCHEDULE_NAME, + schedule_cron_expression=CRON_HOURLY, + endpoint_input=endpoint_input, + arguments=ARGUMENTS, + output=MonitoringOutput(source="/opt/ml/processing/output", destination=OUTPUT_S3_URI), + ) + + sagemaker_session.create_monitoring_schedule.assert_called_with( + monitoring_schedule_name="schedule", + schedule_expression="cron(0 * ? * * *)", + statistics_s3_uri=None, + constraints_s3_uri=None, + monitoring_inputs=[MODEL_QUALITY_MONITOR_JOB_INPUT], + monitoring_output_config=JOB_OUTPUT_CONFIG, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + volume_size_in_gb=VOLUME_SIZE_IN_GB, + volume_kms_key=VOLUME_KMS_KEY, + image_uri=DEFAULT_IMAGE_URI, + entrypoint=None, + arguments=ARGUMENTS, + record_preprocessor_source_uri=None, + post_analytics_processor_source_uri=None, + max_runtime_in_seconds=3, + environment=ENVIRONMENT, + network_config=NETWORK_CONFIG._to_request_dict(), + role_arn=ROLE, + tags=TAGS, + )