Skip to content

Commit e9719b3

Browse files
committed
feature: add a few minor features to Model Monitoring (aws#268)
* predictor.list_monitors() now returns a list of monitors. * monitor.update/start/stop now waits for schedule to finish updating before returning. * monitor.create will fail if a schedule was already created, explaining to the user that they must first delete a schedule before creating a new one. * fixed shallow copy on env to avoid mutating customer environment dictionary * fix integration tests based on latest API changes * no longer relying on schedule defaults in ITs until API is updated. * add missed imports to init to simplify user experience * remove broken assert because API now randomizes schedule if not provided * fix bug when attaching to monitor with entrypoint not specified
1 parent 2226956 commit e9719b3

19 files changed

+367
-733
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,8 @@ def deploy(
615615
kms_key (str): The ARN of the KMS key that is used to encrypt the
616616
data on the storage volume attached to the instance hosting the
617617
endpoint.
618-
data_capture_config (DataCaptureConfig): Specifies configuration
619-
related to Endpoint data capture for use with
618+
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
619+
configuration related to Endpoint data capture for use with
620620
Amazon SageMaker Model Monitoring. Default: None.
621621
**kwargs: Passed to invocation of ``create_model()``.
622622
Implementations may customize ``create_model()`` to accept

src/sagemaker/model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,8 @@ def deploy(
425425
endpoint.
426426
wait (bool): Whether the call should wait until the deployment of
427427
this model completes (default: True).
428-
data_capture_config (DataCaptureConfig): Specifies configuration
429-
related to Endpoint data capture for use with
428+
data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
429+
configuration related to Endpoint data capture for use with
430430
Amazon SageMaker Model Monitoring. Default: None.
431431
432432
Returns:
@@ -458,9 +458,9 @@ def deploy(
458458
if self._is_compiled_model and not self.endpoint_name.endswith(compiled_model_suffix):
459459
self.endpoint_name += compiled_model_suffix
460460

461-
data_capture_config_dict = (
462-
data_capture_config.to_request_dict() if data_capture_config else None
463-
)
461+
data_capture_config_dict = None
462+
if data_capture_config is not None:
463+
data_capture_config_dict = data_capture_config.to_request_dict()
464464

465465
if update_endpoint:
466466
endpoint_config_name = self.sagemaker_session.create_endpoint_config(

src/sagemaker/model_monitor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020

2121
from sagemaker.model_monitor.model_monitoring import ModelMonitor # noqa: F401
2222
from sagemaker.model_monitor.model_monitoring import DefaultModelMonitor # noqa: F401
23+
from sagemaker.model_monitor.model_monitoring import BaseliningJob # noqa: F401
24+
from sagemaker.model_monitor.model_monitoring import MonitoringExecution # noqa: F401
25+
from sagemaker.model_monitor.model_monitoring import EndpointInput # noqa: F401
2326
from sagemaker.model_monitor.model_monitoring import MonitoringOutput # noqa: F401
2427

2528
from sagemaker.model_monitor.cron_expression_generator import CronExpressionGenerator # noqa: F401
@@ -28,5 +31,6 @@
2831
from sagemaker.model_monitor.monitoring_files import ConstraintViolations # noqa: F401
2932

3033
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig # noqa: F401
34+
from sagemaker.model_monitor.dataset_format import DatasetFormat # noqa: F401
3135

3236
from sagemaker.network import NetworkConfig # noqa: F401

src/sagemaker/model_monitor/data_capture_config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
sampling_percentage (int): Optional. Default=20. The percentage of data to sample.
5050
Must be between 0 and 100.
5151
destination_s3_uri (str): Optional. Defaults to "s3://<default-session-bucket>/
52-
<model-monitor>/data-capture
52+
model-monitor/data-capture".
5353
kms_key_id (str): Optional. Default=None. The kms key to use when writing to S3.
5454
capture_options ([str]): Optional. Must be a list containing any combination of the
5555
following values: "REQUEST", "RESPONSE". Default=["REQUEST", "RESPONSE"]. Denotes
@@ -78,8 +78,9 @@ def to_request_dict(self):
7878
"InitialSamplingPercentage": self.sampling_percentage,
7979
"DestinationS3Uri": self.destination_s3_uri,
8080
"CaptureOptions": [
81-
{"CaptureMode": dict(self.API_MAPPING).get(capture_option.upper(), capture_option)}
82-
for capture_option in list(self.capture_options)
81+
# Convert to API values or pass value directly through if unable to convert.
82+
{"CaptureMode": self.API_MAPPING.get(capture_option.upper(), capture_option)}
83+
for capture_option in self.capture_options
8384
],
8485
}
8586

src/sagemaker/model_monitor/dataset_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def json(lines=True):
4242
"""Returns a DatasetFormat JSON string for use with a DefaultModelMonitor.
4343
4444
Args:
45-
lines (bool): Read the file as a json object per line. Default: True.
45+
lines (bool): Whether the file should be read as a json object per line. Default: True.
4646
4747
Returns:
4848
dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor.

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""This module contains code related to Amazon SageMaker Monitoring Schedules. These
14-
classes assist with suggesting baselines and creating monitoring schedules for data captured
13+
"""This module contains code related to Amazon SageMaker Model Monitoring. These classes
14+
assist with suggesting baselines and creating monitoring schedules for data captured
1515
by SageMaker Endpoints.
1616
"""
1717
from __future__ import print_function, absolute_import
1818

19+
import copy
1920
import json
2021
import os
2122
import logging
@@ -35,10 +36,10 @@
3536
from sagemaker.processing import ProcessingJob
3637
from sagemaker.processing import ProcessingInput
3738
from sagemaker.processing import ProcessingOutput
38-
from sagemaker.model_monitor.cron_expression_generator import CronExpressionGenerator
3939
from sagemaker.model_monitor.monitoring_files import Constraints, ConstraintViolations
4040
from sagemaker.model_monitor.monitoring_files import Statistics
4141
from sagemaker.exceptions import UnexpectedStatusException
42+
from sagemaker.utils import retries
4243

4344
_DEFAULT_MONITOR_IMAGE_URI_WITH_PLACEHOLDERS = (
4445
"{}.dkr.ecr.{}.amazonaws.com/sagemaker-model-monitor-analyzer"
@@ -58,8 +59,7 @@
5859
"ap-northeast-2": "709848358524",
5960
"eu-west-2": "749857270468",
6061
"ap-northeast-1": "574779866223",
61-
"us-west-2": "159807026194", # Prod
62-
# "us-west-2": "894667893881", # Gamma. # TODO-reinvent-2019 [knakad]: Remove this.
62+
"us-west-2": "159807026194",
6363
"us-west-1": "890145073186",
6464
"ap-southeast-1": "245545462676",
6565
"ap-southeast-2": "563025443158",
@@ -86,7 +86,6 @@
8686

8787
_SUGGESTION_JOB_BASE_NAME = "baseline-suggestion-job"
8888
_MONITORING_SCHEDULE_BASE_NAME = "monitoring-schedule"
89-
_SCHEDULE_NAME_SUFFIX = "monitoring-schedule"
9089

9190
_DATASET_SOURCE_PATH_ENV_NAME = "dataset_source"
9291
_DATASET_FORMAT_ENV_NAME = "dataset_format"
@@ -96,7 +95,6 @@
9695
_PUBLISH_CLOUDWATCH_METRICS_ENV_NAME = "publish_cloudwatch_metrics"
9796

9897
_LOGGER = logging.getLogger(__name__)
99-
# TODO-reinvent-2019 [knakad]: Review all docstrings.
10098

10199

102100
class ModelMonitor(object):
@@ -239,8 +237,8 @@ def create_monitoring_schedule(
239237
output,
240238
statistics=None,
241239
constraints=None,
242-
monitor_schedule_name=None, # TODO-reinvent-2019 [knakad]: Change to mon_sched_name evwhere
243-
schedule_cron_expression=CronExpressionGenerator.hourly(),
240+
monitor_schedule_name=None,
241+
schedule_cron_expression=None,
244242
):
245243
"""Creates a monitoring schedule to monitor an Amazon SageMaker Endpoint.
246244
@@ -266,9 +264,18 @@ def create_monitoring_schedule(
266264
a default job name, based on the image name and current timestamp.
267265
schedule_cron_expression (str): The cron expression that dictates the frequency that
268266
this job runs at. See sagemaker.model_monitor.CronExpressionGenerator for valid
269-
expressions.
267+
expressions. Default: Daily.
270268
271269
"""
270+
if self.monitoring_schedule_name is not None:
271+
message = (
272+
"It seems that this object was already used to create an Amazon Model "
273+
"Monitoring Schedule. To create another, first delete the existing one "
274+
"using my_monitor.delete_monitoring_schedule()."
275+
)
276+
print(message)
277+
raise ValueError(message)
278+
272279
self.monitoring_schedule_name = self._generate_monitoring_schedule_name(
273280
schedule_name=monitor_schedule_name
274281
)
@@ -474,23 +481,30 @@ def update_monitoring_schedule(
474481
role_arn=role,
475482
)
476483

484+
self._wait_for_schedule_changes_to_apply()
485+
477486
def start_monitoring_schedule(self):
478487
"""Starts the monitoring schedule."""
479488
self.sagemaker_session.start_monitoring_schedule(
480489
monitoring_schedule_name=self.monitoring_schedule_name
481490
)
482491

492+
self._wait_for_schedule_changes_to_apply()
493+
483494
def stop_monitoring_schedule(self):
484495
"""Stops the monitoring schedule."""
485496
self.sagemaker_session.stop_monitoring_schedule(
486497
monitoring_schedule_name=self.monitoring_schedule_name
487498
)
488499

500+
self._wait_for_schedule_changes_to_apply()
501+
489502
def delete_monitoring_schedule(self):
490503
"""Deletes the monitoring schedule."""
491504
self.sagemaker_session.delete_monitoring_schedule(
492505
monitoring_schedule_name=self.monitoring_schedule_name
493506
)
507+
self.monitoring_schedule_name = None
494508

495509
def baseline_statistics(self, file_name=STATISTICS_JSON_DEFAULT_FILE_NAME):
496510
"""Returns a Statistics object representing the statistics json file generated by the
@@ -665,7 +679,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
665679
]["ClusterConfig"]["InstanceType"]
666680
entrypoint = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
667681
"MonitoringAppSpecification"
668-
]["ContainerEntrypoint"]
682+
].get("ContainerEntrypoint")
669683
volume_size_in_gb = schedule_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
670684
"MonitoringResources"
671685
]["ClusterConfig"]["VolumeSizeInGB"]
@@ -744,7 +758,7 @@ def _generate_baselining_job_name(self, job_name=None):
744758
return job_name
745759

746760
if self.base_job_name:
747-
base_name = "{}-{}".format(self.base_job_name, _SCHEDULE_NAME_SUFFIX)
761+
base_name = self.base_job_name
748762
else:
749763
base_name = _SUGGESTION_JOB_BASE_NAME
750764

@@ -932,6 +946,20 @@ def _s3_uri_from_local_path(self, path):
932946
path = os.path.join(s3_uri, os.path.basename(path))
933947
return path
934948

949+
def _wait_for_schedule_changes_to_apply(self):
950+
"""Waits for the schedule associated with this monitor to no longer be in the 'Pending'
951+
state.
952+
953+
"""
954+
for _ in retries(
955+
max_retry_count=36, # 36*5 = 3min
956+
exception_message_prefix="Waiting for schedule to leave 'Pending' status",
957+
seconds_to_sleep=5,
958+
):
959+
schedule_desc = self.describe_schedule()
960+
if schedule_desc["MonitoringScheduleStatus"] != "Pending":
961+
break
962+
935963

936964
class DefaultModelMonitor(ModelMonitor):
937965
"""Sets up Amazon SageMaker Monitoring Schedules and baseline suggestions. Use this class when
@@ -1088,7 +1116,6 @@ def suggest_baseline(
10881116
dataset_format=dataset_format,
10891117
output_path=normalized_baseline_output.source,
10901118
enable_cloudwatch_metrics=False, # Only supported for monitoring schedules
1091-
# TODO-reinvent-2019 [knakad]: Remove this once API stops failing if not provided.
10921119
dataset_source_container_path=baseline_dataset_container_path,
10931120
record_preprocessor_script_container_path=record_preprocessor_script_container_path,
10941121
post_processor_script_container_path=post_processor_script_container_path,
@@ -1147,8 +1174,7 @@ def create_monitoring_schedule(
11471174
constraints=None,
11481175
statistics=None,
11491176
monitor_schedule_name=None,
1150-
schedule_cron_expression=CronExpressionGenerator.hourly(),
1151-
# TODO-reinvent-2019 [knakad]: Service to default this to daily at a random hour
1177+
schedule_cron_expression=None,
11521178
enable_cloudwatch_metrics=True,
11531179
):
11541180
"""Creates a monitoring schedule to monitor an Amazon SageMaker Endpoint.
@@ -1179,11 +1205,20 @@ def create_monitoring_schedule(
11791205
a default job name, based on the image name and current timestamp.
11801206
schedule_cron_expression (str): The cron expression that dictates the frequency that
11811207
this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
1182-
expressions.
1208+
expressions. Default: Daily.
11831209
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
11841210
the baselining or monitoring jobs.
11851211
11861212
"""
1213+
if self.monitoring_schedule_name is not None:
1214+
message = (
1215+
"It seems that this object was already used to create an Amazon Model "
1216+
"Monitoring Schedule. To create another, first delete the existing one "
1217+
"using my_monitor.delete_monitoring_schedule()."
1218+
)
1219+
print(message)
1220+
raise ValueError(message)
1221+
11871222
self.monitoring_schedule_name = self._generate_monitoring_schedule_name(
11881223
schedule_name=monitor_schedule_name
11891224
)
@@ -1354,12 +1389,7 @@ def update_monitoring_schedule(
13541389
self.env = env
13551390

13561391
normalized_env = self._generate_env_map(
1357-
env=env,
1358-
# dataset_format=DatasetFormat.sagemaker_capture_json(),
1359-
output_path=output_path,
1360-
enable_cloudwatch_metrics=enable_cloudwatch_metrics,
1361-
# record_preprocessor_script_input=record_preprocessor_script_input,
1362-
# post_analytics_processor_script_input=post_analytics_processor_script_input,
1392+
env=env, output_path=output_path, enable_cloudwatch_metrics=enable_cloudwatch_metrics
13631393
)
13641394

13651395
statistics_object, constraints_object = self._get_baseline_files(
@@ -1422,6 +1452,8 @@ def update_monitoring_schedule(
14221452
role_arn=role,
14231453
)
14241454

1455+
self._wait_for_schedule_changes_to_apply()
1456+
14251457
def run_baseline(self):
14261458
"""'.run_baseline()' is only allowed for ModelMonitor objects. Please use suggest_baseline
14271459
for DefaultModelMonitor objects, instead."""
@@ -1569,8 +1601,8 @@ def latest_monitoring_constraint_violations(self):
15691601
except ClientError:
15701602
status = latest_monitoring_execution.describe()["ProcessingJobStatus"]
15711603
print(
1572-
"Unable to retrieve statistics as job is in status '{}'. Latest violations only "
1573-
"available for completed executions.".format(status)
1604+
"Unable to retrieve constraint violations as job is in status '{}'. Latest "
1605+
"violations only available for completed executions.".format(status)
15741606
)
15751607

15761608
def _normalize_baseline_output(self, output_s3_uri=None):
@@ -1649,7 +1681,7 @@ def _generate_env_map(
16491681
cloudwatch_env_map = {True: "Enabled", False: "Disabled"}
16501682

16511683
if env is not None:
1652-
env = env.copy()
1684+
env = copy.deepcopy(env)
16531685
env = env or {}
16541686

16551687
if output_path is not None:
@@ -1672,12 +1704,6 @@ def _generate_env_map(
16721704
if dataset_source_container_path is not None:
16731705
env[_DATASET_SOURCE_PATH_ENV_NAME] = dataset_source_container_path
16741706

1675-
# if dataset_source_input is not None:
1676-
# dataset_source_input_container_path = os.path.join(
1677-
# dataset_source_input.destination, os.path.basename(dataset_source_input.source)
1678-
# )
1679-
# env[_DATASET_SOURCE_PATH_ENV_NAME] = dataset_source_input_container_path
1680-
16811707
return env
16821708

16831709
def _upload_and_convert_to_processing_input(self, source, destination, name):
@@ -1808,7 +1834,7 @@ def baseline_statistics(self, file_name=STATISTICS_JSON_DEFAULT_FILE_NAME, kms_k
18081834
actual_status=status,
18091835
)
18101836
else:
1811-
raise
1837+
raise client_error
18121838

18131839
def suggested_constraints(self, file_name=CONSTRAINTS_JSON_DEFAULT_FILE_NAME, kms_key=None):
18141840
"""Returns a sagemaker.model_monitor.Constraints object representing the constraints
@@ -1845,7 +1871,7 @@ def suggested_constraints(self, file_name=CONSTRAINTS_JSON_DEFAULT_FILE_NAME, km
18451871
actual_status=status,
18461872
)
18471873
else:
1848-
raise
1874+
raise client_error
18491875

18501876

18511877
class MonitoringExecution(ProcessingJob):
@@ -1956,7 +1982,7 @@ def statistics(self, file_name=STATISTICS_JSON_DEFAULT_FILE_NAME, kms_key=None):
19561982
actual_status=status,
19571983
)
19581984
else:
1959-
raise
1985+
raise client_error
19601986

19611987
def constraint_violations(
19621988
self, file_name=CONSTRAINT_VIOLATIONS_JSON_DEFAULT_FILE_NAME, kms_key=None
@@ -1997,7 +2023,7 @@ def constraint_violations(
19972023
actual_status=status,
19982024
)
19992025
else:
2000-
raise
2026+
raise client_error
20012027

20022028

20032029
class EndpointInput(object):

0 commit comments

Comments
 (0)