Skip to content

Commit 0ab3048

Browse files
keshav-chandakKeshav Chandak
and
Keshav Chandak
authored
fix: added support to update arguments in create_monitoring_schedule (#3664)
Co-authored-by: Keshav Chandak <[email protected]>
1 parent 8dc33e8 commit 0ab3048

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

src/sagemaker/model_monitor/model_monitoring.py

+5
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def create_monitoring_schedule(
233233
monitor_schedule_name=None,
234234
schedule_cron_expression=None,
235235
batch_transform_input=None,
236+
arguments=None,
236237
):
237238
"""Creates a monitoring schedule to monitor an Amazon SageMaker Endpoint.
238239
@@ -262,6 +263,7 @@ def create_monitoring_schedule(
262263
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
263264
run the monitoring schedule on the batch transform
264265
(default: None)
266+
arguments ([str]): A list of string arguments to be passed to a processing job.
265267
266268
"""
267269
if self.monitoring_schedule_name is not None:
@@ -326,6 +328,9 @@ def create_monitoring_schedule(
326328
if self.network_config is not None:
327329
network_config_dict = self.network_config._to_request_dict()
328330

331+
if arguments is not None:
332+
self.arguments = arguments
333+
329334
self.sagemaker_session.create_monitoring_schedule(
330335
monitoring_schedule_name=self.monitoring_schedule_name,
331336
schedule_expression=schedule_cron_expression,

tests/unit/sagemaker/monitor/test_model_monitoring.py

+86
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
BatchTransformInput,
3030
ModelQualityMonitor,
3131
Statistics,
32+
MonitoringOutput,
3233
)
3334
from sagemaker.model_monitor.monitoring_alert import (
3435
MonitoringAlertSummary,
@@ -132,6 +133,7 @@
132133
POSTPROCESSOR_URI = "s3://my_bucket/postprocessor.py"
133134
DATA_CAPTURED_S3_URI = "s3://my-bucket/batch-fraud-detection/on-schedule-monitoring/in/"
134135
DATASET_FORMAT = MonitoringDatasetFormat.csv(header=False)
136+
ARGUMENTS = ["test-argument"]
135137
JOB_OUTPUT_CONFIG = {
136138
"MonitoringOutputs": [
137139
{
@@ -372,6 +374,20 @@
372374
},
373375
"GroundTruthS3Input": {"S3Uri": NEW_GROUND_TRUTH_S3_URI},
374376
}
377+
MODEL_QUALITY_MONITOR_JOB_INPUT = {
378+
"EndpointInput": {
379+
"EndpointName": ENDPOINT_NAME,
380+
"LocalPath": ENDPOINT_INPUT_LOCAL_PATH,
381+
"S3InputMode": S3_INPUT_MODE,
382+
"S3DataDistributionType": S3_DATA_DISTRIBUTION_TYPE,
383+
"StartTimeOffset": START_TIME_OFFSET,
384+
"EndTimeOffset": END_TIME_OFFSET,
385+
"FeaturesAttribute": FEATURES_ATTRIBUTE,
386+
"InferenceAttribute": INFERENCE_ATTRIBUTE,
387+
"ProbabilityAttribute": PROBABILITY_ATTRIBUTE,
388+
"ProbabilityThresholdAttribute": PROBABILITY_THRESHOLD_ATTRIBUTE,
389+
},
390+
}
375391
NEW_MODEL_QUALITY_JOB_DEFINITION = {
376392
"ModelQualityAppSpecification": NEW_MODEL_QUALITY_APP_SPECIFICATION,
377393
"ModelQualityJobInput": NEW_MODEL_QUALITY_JOB_INPUT,
@@ -483,6 +499,25 @@ def data_quality_monitor(sagemaker_session):
483499
)
484500

485501

502+
@pytest.fixture()
503+
def model_monitor_arguments(sagemaker_session):
504+
return ModelMonitor(
505+
role=ROLE,
506+
instance_count=INSTANCE_COUNT,
507+
instance_type=INSTANCE_TYPE,
508+
volume_size_in_gb=VOLUME_SIZE_IN_GB,
509+
volume_kms_key=VOLUME_KMS_KEY,
510+
output_kms_key=OUTPUT_KMS_KEY,
511+
max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
512+
base_job_name=BASE_JOB_NAME,
513+
sagemaker_session=sagemaker_session,
514+
env=ENVIRONMENT,
515+
tags=TAGS,
516+
network_config=NETWORK_CONFIG,
517+
image_uri=DefaultModelMonitor._get_default_image_uri(REGION),
518+
)
519+
520+
486521
def test_default_model_monitor_suggest_baseline(sagemaker_session):
487522
my_default_monitor = DefaultModelMonitor(
488523
role=ROLE,
@@ -1691,3 +1726,54 @@ def test_batch_transform_and_endpoint_input_simultaneous_failure(
16911726
)
16921727
except Exception as e:
16931728
assert "Need to have either batch_transform_input or endpoint_input" in str(e)
1729+
1730+
1731+
def test_model_monitor_with_arguments(
1732+
model_monitor_arguments,
1733+
sagemaker_session,
1734+
constraints=None,
1735+
statistics=None,
1736+
endpoint_input=EndpointInput(
1737+
endpoint_name=ENDPOINT_NAME,
1738+
destination=ENDPOINT_INPUT_LOCAL_PATH,
1739+
start_time_offset=START_TIME_OFFSET,
1740+
end_time_offset=END_TIME_OFFSET,
1741+
features_attribute=FEATURES_ATTRIBUTE,
1742+
inference_attribute=INFERENCE_ATTRIBUTE,
1743+
probability_attribute=PROBABILITY_ATTRIBUTE,
1744+
probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE,
1745+
),
1746+
):
1747+
# for batch transform input
1748+
model_monitor_arguments.create_monitoring_schedule(
1749+
constraints=constraints,
1750+
statistics=statistics,
1751+
monitor_schedule_name=SCHEDULE_NAME,
1752+
schedule_cron_expression=CRON_HOURLY,
1753+
endpoint_input=endpoint_input,
1754+
arguments=ARGUMENTS,
1755+
output=MonitoringOutput(source="/opt/ml/processing/output", destination=OUTPUT_S3_URI),
1756+
)
1757+
1758+
sagemaker_session.create_monitoring_schedule.assert_called_with(
1759+
monitoring_schedule_name="schedule",
1760+
schedule_expression="cron(0 * ? * * *)",
1761+
statistics_s3_uri=None,
1762+
constraints_s3_uri=None,
1763+
monitoring_inputs=[MODEL_QUALITY_MONITOR_JOB_INPUT],
1764+
monitoring_output_config=JOB_OUTPUT_CONFIG,
1765+
instance_count=INSTANCE_COUNT,
1766+
instance_type=INSTANCE_TYPE,
1767+
volume_size_in_gb=VOLUME_SIZE_IN_GB,
1768+
volume_kms_key=VOLUME_KMS_KEY,
1769+
image_uri=DEFAULT_IMAGE_URI,
1770+
entrypoint=None,
1771+
arguments=ARGUMENTS,
1772+
record_preprocessor_source_uri=None,
1773+
post_analytics_processor_source_uri=None,
1774+
max_runtime_in_seconds=3,
1775+
environment=ENVIRONMENT,
1776+
network_config=NETWORK_CONFIG._to_request_dict(),
1777+
role_arn=ROLE,
1778+
tags=TAGS,
1779+
)

0 commit comments

Comments
 (0)