Skip to content

Commit 3d61dd5

Browse files
kusumbhattKusumlata
and
Kusumlata
authored
feat: Add optional monitoring_config_override parameter in suggest_baseline API (#3939)
Co-authored-by: Kusumlata <[email protected]>
1 parent 566502f commit 3d61dd5

File tree

7 files changed

+616
-2
lines changed

7 files changed

+616
-2
lines changed

doc/api/inference/model_monitor.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,8 @@ Model Monitor
4141
:members:
4242
:undoc-members:
4343
:show-inheritance:
44+
45+
.. automodule:: sagemaker.model_monitor.data_quality_monitoring_config
46+
:members:
47+
:undoc-members:
48+
:show-inheritance:

src/sagemaker/model_monitor/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,10 @@
4646
from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat # noqa: F401
4747

4848
from sagemaker.network import NetworkConfig # noqa: F401
49+
50+
from sagemaker.model_monitor.data_quality_monitoring_config import ( # noqa: F401
51+
DataQualityMonitoringConfig,
52+
)
53+
from sagemaker.model_monitor.data_quality_monitoring_config import ( # noqa: F401
54+
DataQualityDistributionConstraints,
55+
)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains code related to the MonitoringConfig of constraints file.
14+
15+
Code is used to represent the Monitoring Config object and its parameters suggested
16+
in constraints file by Model Monitor Container in data quality analysis.
17+
"""
18+
from __future__ import print_function, absolute_import
19+
20+
# Categorical drift methods accepted by the validators below.
CHI_SQUARED_METHOD = "ChiSquared"
L_INFINITY_METHOD = "LInfinity"


class DataQualityDistributionConstraints:
    """Represents the distribution_constraints object of monitoring_config in constraints file."""

    def __init__(self, categorical_drift_method: str = None):
        """Initialize the distribution constraints.

        Args:
            categorical_drift_method (str): Categorical drift method to use, one of
                ``CHI_SQUARED_METHOD`` or ``L_INFINITY_METHOD``. ``None`` (default)
                leaves the method unset.
        """
        self.categorical_drift_method = categorical_drift_method

    @staticmethod
    def valid_distribution_constraints(distribution_constraints):
        """Checks whether distribution_constraints are valid or not.

        Args:
            distribution_constraints: Object exposing a ``categorical_drift_method``
                attribute, or a falsy value (e.g. ``None``) when nothing is overridden.

        Returns:
            bool: True if nothing is specified or the contained
                categorical_drift_method is valid, False otherwise.
        """
        if not distribution_constraints:
            return True

        # An object without the expected attribute (e.g. a plain dict) is invalid input;
        # report it as such instead of letting an AttributeError escape the validator.
        if not hasattr(distribution_constraints, "categorical_drift_method"):
            return False

        return DataQualityDistributionConstraints.valid_categorical_drift_method(
            distribution_constraints.categorical_drift_method
        )

    @staticmethod
    def valid_categorical_drift_method(categorical_drift_method):
        """Checks whether categorical_drift_method is valid or not.

        Args:
            categorical_drift_method: Value to validate.

        Returns:
            bool: True if the value is ``None`` (not specified) or one of the
                supported methods, False otherwise.
        """
        # Only None means "not specified". Other falsy values such as "" are not
        # supported methods and must be rejected rather than silently accepted.
        if categorical_drift_method is None:
            return True

        return categorical_drift_method in (CHI_SQUARED_METHOD, L_INFINITY_METHOD)
49+
50+
51+
class DataQualityMonitoringConfig:
    """Represents monitoring_config object in constraints file."""

    def __init__(self, distribution_constraints: DataQualityDistributionConstraints = None):
        """Initialize the monitoring config.

        Args:
            distribution_constraints (DataQualityDistributionConstraints): Distribution
                constraints carried by this config. ``None`` (default) leaves them unset.
        """
        self.distribution_constraints = distribution_constraints

    @staticmethod
    def valid_monitoring_config(monitoring_config):
        """Checks whether monitoring_config is valid or not.

        Args:
            monitoring_config: Object exposing a ``distribution_constraints`` attribute,
                or a falsy value (e.g. ``None``) when no config is supplied.

        Returns:
            bool: True if no config is supplied or its distribution_constraints pass
                ``DataQualityDistributionConstraints.valid_distribution_constraints``,
                False otherwise.
        """
        if not monitoring_config:
            return True

        # An object without the expected attribute (e.g. a plain dict) is invalid input;
        # report it as such instead of letting an AttributeError escape the validator.
        if not hasattr(monitoring_config, "distribution_constraints"):
            return False

        return DataQualityDistributionConstraints.valid_distribution_constraints(
            monitoring_config.distribution_constraints
        )

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
MonitoringAlertActions,
5353
ModelDashboardIndicatorAction,
5454
)
55+
from sagemaker.model_monitor.data_quality_monitoring_config import DataQualityMonitoringConfig
5556
from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat
5657
from sagemaker.network import NetworkConfig
5758
from sagemaker.processing import Processor, ProcessingInput, ProcessingJob, ProcessingOutput
@@ -98,6 +99,7 @@
9899
_INFERENCE_ATTRIBUTE_ENV_NAME = "inference_attribute"
99100
_PROBABILITY_ATTRIBUTE_ENV_NAME = "probability_attribute"
100101
_PROBABILITY_THRESHOLD_ATTRIBUTE_ENV_NAME = "probability_threshold_attribute"
102+
_CATEGORICAL_DRIFT_METHOD_ENV_NAME = "categorical_drift_method"
101103

102104
_LOGGER = logging.getLogger(__name__)
103105

@@ -1136,6 +1138,7 @@ def _generate_env_map(
11361138
probability_attribute=None,
11371139
ground_truth_attribute=None,
11381140
probability_threshold_attribute=None,
1141+
categorical_drift_method=None,
11391142
):
11401143
"""Generate a list of environment variables from first-class parameters.
11411144
@@ -1157,6 +1160,9 @@ def _generate_env_map(
11571160
Only used for ModelQualityMonitor.
11581161
probability_threshold_attribute (float): threshold to convert probabilities to binaries
11591162
Only used for ModelQualityMonitor.
1163+
categorical_drift_method (str): categorical_drift_method to override the
1164+
categorical_drift_method of global monitoring_config in constraints
1165+
suggested by Model Monitor container. Only used for DataQualityMonitor.
11601166
11611167
Returns:
11621168
dict: Dictionary of environment keys and values.
@@ -1206,6 +1212,9 @@ def _generate_env_map(
12061212
if probability_threshold_attribute is not None:
12071213
env[_PROBABILITY_THRESHOLD_ATTRIBUTE_ENV_NAME] = probability_threshold_attribute
12081214

1215+
if categorical_drift_method is not None:
1216+
env[_CATEGORICAL_DRIFT_METHOD_ENV_NAME] = categorical_drift_method
1217+
12091218
return env
12101219

12111220
@staticmethod
@@ -1647,6 +1656,7 @@ def suggest_baseline(
16471656
wait=True,
16481657
logs=True,
16491658
job_name=None,
1659+
monitoring_config_override=None,
16501660
):
16511661
"""Suggest baselines for use with Amazon SageMaker Model Monitoring Schedules.
16521662
@@ -1666,12 +1676,18 @@ def suggest_baseline(
16661676
Only meaningful when wait is True (default: True).
16671677
job_name (str): Processing job name. If not specified, the processor generates
16681678
a default job name, based on the image name and current timestamp.
1669-
1679+
monitoring_config_override (DataQualityMonitoringConfig): monitoring_config object to
1680+
override the global monitoring_config parameter of constraints suggested by
1681+
Model Monitor Container. If not specified, the values suggested by the container are
1682+
used.
16701683
Returns:
16711684
sagemaker.processing.ProcessingJob: The ProcessingJob object representing the
16721685
baselining job.
16731686
16741687
"""
1688+
if not DataQualityMonitoringConfig.valid_monitoring_config(monitoring_config_override):
1689+
raise RuntimeError("Invalid value for monitoring_config_override.")
1690+
16751691
self.latest_baselining_job_name = self._generate_baselining_job_name(job_name=job_name)
16761692

16771693
normalized_baseline_dataset_input = self._upload_and_convert_to_processing_input(
@@ -1731,6 +1747,11 @@ def suggest_baseline(
17311747

17321748
normalized_baseline_output = self._normalize_baseline_output(output_s3_uri=output_s3_uri)
17331749

1750+
categorical_drift_method = None
1751+
if monitoring_config_override and monitoring_config_override.distribution_constraints:
1752+
distribution_constraints = monitoring_config_override.distribution_constraints
1753+
categorical_drift_method = distribution_constraints.categorical_drift_method
1754+
17341755
normalized_env = self._generate_env_map(
17351756
env=self.env,
17361757
dataset_format=dataset_format,
@@ -1739,6 +1760,7 @@ def suggest_baseline(
17391760
dataset_source_container_path=baseline_dataset_container_path,
17401761
record_preprocessor_script_container_path=record_preprocessor_script_container_path,
17411762
post_processor_script_container_path=post_processor_script_container_path,
1763+
categorical_drift_method=categorical_drift_method,
17421764
)
17431765

17441766
baselining_processor = Processor(

0 commit comments

Comments
 (0)