Skip to content

Bugfix: Added support to update arguments in create_monitoring_schedule #3664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/sagemaker/model_monitor/model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def create_monitoring_schedule(
monitor_schedule_name=None,
schedule_cron_expression=None,
batch_transform_input=None,
arguments=None,
):
"""Creates a monitoring schedule to monitor an Amazon SageMaker Endpoint.

Expand Down Expand Up @@ -262,6 +263,7 @@ def create_monitoring_schedule(
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
run the monitoring schedule on the batch transform
(default: None)
arguments ([str]): A list of string arguments to be passed to a processing job.

"""
if self.monitoring_schedule_name is not None:
Expand Down Expand Up @@ -326,6 +328,9 @@ def create_monitoring_schedule(
if self.network_config is not None:
network_config_dict = self.network_config._to_request_dict()

if arguments is not None:
self.arguments = arguments

self.sagemaker_session.create_monitoring_schedule(
monitoring_schedule_name=self.monitoring_schedule_name,
schedule_expression=schedule_cron_expression,
Expand Down
86 changes: 86 additions & 0 deletions tests/unit/sagemaker/monitor/test_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
BatchTransformInput,
ModelQualityMonitor,
Statistics,
MonitoringOutput,
)
from sagemaker.model_monitor.monitoring_alert import (
MonitoringAlertSummary,
Expand Down Expand Up @@ -132,6 +133,7 @@
POSTPROCESSOR_URI = "s3://my_bucket/postprocessor.py"
DATA_CAPTURED_S3_URI = "s3://my-bucket/batch-fraud-detection/on-schedule-monitoring/in/"
DATASET_FORMAT = MonitoringDatasetFormat.csv(header=False)
ARGUMENTS = ["test-argument"]
JOB_OUTPUT_CONFIG = {
"MonitoringOutputs": [
{
Expand Down Expand Up @@ -372,6 +374,20 @@
},
"GroundTruthS3Input": {"S3Uri": NEW_GROUND_TRUTH_S3_URI},
}
# Expected "monitoring input" request payload for an EndpointInput-based
# schedule — mirrors the EndpointInput kwargs used elsewhere in this module,
# so assert_called_with checks can compare against the serialized form.
MODEL_QUALITY_MONITOR_JOB_INPUT = {
    "EndpointInput": {
        "EndpointName": ENDPOINT_NAME,
        "LocalPath": ENDPOINT_INPUT_LOCAL_PATH,
        "S3InputMode": S3_INPUT_MODE,
        "S3DataDistributionType": S3_DATA_DISTRIBUTION_TYPE,
        "StartTimeOffset": START_TIME_OFFSET,
        "EndTimeOffset": END_TIME_OFFSET,
        "FeaturesAttribute": FEATURES_ATTRIBUTE,
        "InferenceAttribute": INFERENCE_ATTRIBUTE,
        "ProbabilityAttribute": PROBABILITY_ATTRIBUTE,
        "ProbabilityThresholdAttribute": PROBABILITY_THRESHOLD_ATTRIBUTE,
    },
}
NEW_MODEL_QUALITY_JOB_DEFINITION = {
"ModelQualityAppSpecification": NEW_MODEL_QUALITY_APP_SPECIFICATION,
"ModelQualityJobInput": NEW_MODEL_QUALITY_JOB_INPUT,
Expand Down Expand Up @@ -483,6 +499,25 @@ def data_quality_monitor(sagemaker_session):
)


@pytest.fixture()
def model_monitor_arguments(sagemaker_session):
    """Return a generic ModelMonitor wired with the module's shared test constants."""
    monitor_kwargs = {
        "role": ROLE,
        "instance_count": INSTANCE_COUNT,
        "instance_type": INSTANCE_TYPE,
        "volume_size_in_gb": VOLUME_SIZE_IN_GB,
        "volume_kms_key": VOLUME_KMS_KEY,
        "output_kms_key": OUTPUT_KMS_KEY,
        "max_runtime_in_seconds": MAX_RUNTIME_IN_SECONDS,
        "base_job_name": BASE_JOB_NAME,
        "sagemaker_session": sagemaker_session,
        "env": ENVIRONMENT,
        "tags": TAGS,
        "network_config": NETWORK_CONFIG,
        # Reuse the region-appropriate default monitoring image.
        "image_uri": DefaultModelMonitor._get_default_image_uri(REGION),
    }
    return ModelMonitor(**monitor_kwargs)


def test_default_model_monitor_suggest_baseline(sagemaker_session):
my_default_monitor = DefaultModelMonitor(
role=ROLE,
Expand Down Expand Up @@ -1691,3 +1726,54 @@ def test_batch_transform_and_endpoint_input_simultaneous_failure(
)
except Exception as e:
assert "Need to have either batch_transform_input or endpoint_input" in str(e)


def test_model_monitor_with_arguments(
    model_monitor_arguments,
    sagemaker_session,
):
    """create_monitoring_schedule(arguments=...) forwards the processing-job
    arguments to sagemaker_session.create_monitoring_schedule.

    Fix notes vs. previous version: the EndpointInput was built as a default
    parameter value (evaluated once at collection time) and the call was
    mislabeled as a batch-transform input; both are corrected here. Only the
    two pytest fixtures remain as parameters.
    """
    # Monitor an endpoint (not a batch transform), so the request should carry
    # an EndpointInput-derived monitoring input.
    endpoint_input = EndpointInput(
        endpoint_name=ENDPOINT_NAME,
        destination=ENDPOINT_INPUT_LOCAL_PATH,
        start_time_offset=START_TIME_OFFSET,
        end_time_offset=END_TIME_OFFSET,
        features_attribute=FEATURES_ATTRIBUTE,
        inference_attribute=INFERENCE_ATTRIBUTE,
        probability_attribute=PROBABILITY_ATTRIBUTE,
        probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE,
    )

    model_monitor_arguments.create_monitoring_schedule(
        constraints=None,
        statistics=None,
        monitor_schedule_name=SCHEDULE_NAME,
        schedule_cron_expression=CRON_HOURLY,
        endpoint_input=endpoint_input,
        arguments=ARGUMENTS,
        output=MonitoringOutput(source="/opt/ml/processing/output", destination=OUTPUT_S3_URI),
    )

    # The session mock should receive the arguments list untouched, alongside
    # the rest of the schedule configuration.
    sagemaker_session.create_monitoring_schedule.assert_called_with(
        monitoring_schedule_name="schedule",
        schedule_expression="cron(0 * ? * * *)",
        statistics_s3_uri=None,
        constraints_s3_uri=None,
        monitoring_inputs=[MODEL_QUALITY_MONITOR_JOB_INPUT],
        monitoring_output_config=JOB_OUTPUT_CONFIG,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        volume_size_in_gb=VOLUME_SIZE_IN_GB,
        volume_kms_key=VOLUME_KMS_KEY,
        image_uri=DEFAULT_IMAGE_URI,
        entrypoint=None,
        arguments=ARGUMENTS,
        record_preprocessor_source_uri=None,
        post_analytics_processor_source_uri=None,
        max_runtime_in_seconds=3,
        environment=ENVIRONMENT,
        network_config=NETWORK_CONFIG._to_request_dict(),
        role_arn=ROLE,
        tags=TAGS,
    )