52
52
MonitoringAlertActions ,
53
53
ModelDashboardIndicatorAction ,
54
54
)
55
+ from sagemaker .model_monitor .data_quality_monitoring_config import DataQualityMonitoringConfig
55
56
from sagemaker .model_monitor .dataset_format import MonitoringDatasetFormat
56
57
from sagemaker .network import NetworkConfig
57
58
from sagemaker .processing import Processor , ProcessingInput , ProcessingJob , ProcessingOutput
98
99
_INFERENCE_ATTRIBUTE_ENV_NAME = "inference_attribute"
99
100
_PROBABILITY_ATTRIBUTE_ENV_NAME = "probability_attribute"
100
101
_PROBABILITY_THRESHOLD_ATTRIBUTE_ENV_NAME = "probability_threshold_attribute"
102
+ _CATEGORICAL_DRIFT_METHOD_ENV_NAME = "categorical_drift_method"
101
103
102
104
_LOGGER = logging .getLogger (__name__ )
103
105
@@ -1136,6 +1138,7 @@ def _generate_env_map(
1136
1138
probability_attribute = None ,
1137
1139
ground_truth_attribute = None ,
1138
1140
probability_threshold_attribute = None ,
1141
+ categorical_drift_method = None ,
1139
1142
):
1140
1143
"""Generate a list of environment variables from first-class parameters.
1141
1144
@@ -1157,6 +1160,9 @@ def _generate_env_map(
1157
1160
Only used for ModelQualityMonitor.
1158
1161
probability_threshold_attribute (float): threshold to convert probabilities to binaries
1159
1162
Only used for ModelQualityMonitor.
1163
+ categorical_drift_method (str): categorical_drift_method to override the
1164
+ categorical_drift_method of global monitoring_config in constraints
1165
+ suggested by Model Monitor container. Only used for DataQualityMonitor.
1160
1166
1161
1167
Returns:
1162
1168
dict: Dictionary of environment keys and values.
@@ -1206,6 +1212,9 @@ def _generate_env_map(
1206
1212
if probability_threshold_attribute is not None :
1207
1213
env [_PROBABILITY_THRESHOLD_ATTRIBUTE_ENV_NAME ] = probability_threshold_attribute
1208
1214
1215
+ if categorical_drift_method is not None :
1216
+ env [_CATEGORICAL_DRIFT_METHOD_ENV_NAME ] = categorical_drift_method
1217
+
1209
1218
return env
1210
1219
1211
1220
@staticmethod
@@ -1647,6 +1656,7 @@ def suggest_baseline(
1647
1656
wait = True ,
1648
1657
logs = True ,
1649
1658
job_name = None ,
1659
+ monitoring_config_override = None ,
1650
1660
):
1651
1661
"""Suggest baselines for use with Amazon SageMaker Model Monitoring Schedules.
1652
1662
@@ -1666,12 +1676,18 @@ def suggest_baseline(
1666
1676
Only meaningful when wait is True (default: True).
1667
1677
job_name (str): Processing job name. If not specified, the processor generates
1668
1678
a default job name, based on the image name and current timestamp.
1669
-
1679
+ monitoring_config_override (DataQualityMonitoringConfig): monitoring_config object to
1680
+ override the global monitoring_config parameter of constraints suggested by
1681
+ Model Monitor Container. If not specified, the values suggested by container is
1682
+ set.
1670
1683
Returns:
1671
1684
sagemaker.processing.ProcessingJob: The ProcessingJob object representing the
1672
1685
baselining job.
1673
1686
1674
1687
"""
1688
+ if not DataQualityMonitoringConfig .valid_monitoring_config (monitoring_config_override ):
1689
+ raise RuntimeError ("Invalid value for monitoring_config_override." )
1690
+
1675
1691
self .latest_baselining_job_name = self ._generate_baselining_job_name (job_name = job_name )
1676
1692
1677
1693
normalized_baseline_dataset_input = self ._upload_and_convert_to_processing_input (
@@ -1731,6 +1747,11 @@ def suggest_baseline(
1731
1747
1732
1748
normalized_baseline_output = self ._normalize_baseline_output (output_s3_uri = output_s3_uri )
1733
1749
1750
+ categorical_drift_method = None
1751
+ if monitoring_config_override and monitoring_config_override .distribution_constraints :
1752
+ distribution_constraints = monitoring_config_override .distribution_constraints
1753
+ categorical_drift_method = distribution_constraints .categorical_drift_method
1754
+
1734
1755
normalized_env = self ._generate_env_map (
1735
1756
env = self .env ,
1736
1757
dataset_format = dataset_format ,
@@ -1739,6 +1760,7 @@ def suggest_baseline(
1739
1760
dataset_source_container_path = baseline_dataset_container_path ,
1740
1761
record_preprocessor_script_container_path = record_preprocessor_script_container_path ,
1741
1762
post_processor_script_container_path = post_processor_script_container_path ,
1763
+ categorical_drift_method = categorical_drift_method ,
1742
1764
)
1743
1765
1744
1766
baselining_processor = Processor (
0 commit comments