Skip to content

Commit 2416f29

Browse files
author
Keshav Chandak
committed
feature: added support for batch transform with model monitoring
1 parent 0914f17 commit 2416f29

14 files changed

+1206
-11
lines changed

src/sagemaker/inputs.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ class TransformInput(object):
176176
output_filter: str = attr.ib(default=None)
177177
join_source: str = attr.ib(default=None)
178178
model_client_config: dict = attr.ib(default=None)
179+
batch_data_capture_config: dict = attr.ib(default=None)
179180

180181

181182
class FileSystemInput(object):
@@ -232,3 +233,43 @@ def __init__(
232233

233234
if content_type:
234235
self.config["ContentType"] = content_type
236+
237+
238+
class BatchDataCaptureConfig(object):
    """Configuration object passed in when creating a batch transform job.

    Specifies configuration related to batch transform job data capture for use with
    Amazon SageMaker Model Monitoring.
    """

    def __init__(
        self,
        destination_s3_uri: str,
        kms_key_id: str = None,
        generate_inference_id: bool = None,
    ):
        """Create a new BatchDataCaptureConfig.

        Args:
            destination_s3_uri (str): S3 location to store the captured data.
            kms_key_id (str): The KMS key to use when writing to S3.
                KmsKeyId can be an ID of a KMS key, ARN of a KMS key, alias of a KMS key,
                or alias of a KMS key ARN. The KmsKeyId is applied to all outputs.
            generate_inference_id (bool): Flag to generate an inference id.
        """
        self.destination_s3_uri = destination_s3_uri
        self.kms_key_id = kms_key_id
        self.generate_inference_id = generate_inference_id

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class.

        Returns:
            dict: DataCaptureConfig request parameters. Optional keys are
            included only when the corresponding attribute was set (note that
            ``generate_inference_id=False`` is still included, since only
            ``None`` means "unset").
        """
        batch_data_capture_config = {
            "DestinationS3Uri": self.destination_s3_uri,
        }

        if self.kms_key_id is not None:
            batch_data_capture_config["KmsKeyId"] = self.kms_key_id
        if self.generate_inference_id is not None:
            batch_data_capture_config["GenerateInferenceId"] = self.generate_inference_id

        return batch_data_capture_config

src/sagemaker/model_monitor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker.model_monitor.model_monitoring import BaseliningJob # noqa: F401
2424
from sagemaker.model_monitor.model_monitoring import MonitoringExecution # noqa: F401
2525
from sagemaker.model_monitor.model_monitoring import EndpointInput # noqa: F401
26+
from sagemaker.model_monitor.model_monitoring import BatchTransformInput # noqa: F401
2627
from sagemaker.model_monitor.model_monitoring import MonitoringOutput # noqa: F401
2728
from sagemaker.model_monitor.model_monitoring import ModelQualityMonitor # noqa: F401
2829

@@ -42,5 +43,6 @@
4243

4344
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig # noqa: F401
4445
from sagemaker.model_monitor.dataset_format import DatasetFormat # noqa: F401
46+
from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat # noqa: F401
4547

4648
from sagemaker.network import NetworkConfig # noqa: F401

src/sagemaker/model_monitor/clarify_model_monitoring.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def _build_create_job_definition_request(
227227
env=None,
228228
tags=None,
229229
network_config=None,
230+
batch_transform_input=None,
230231
):
231232
"""Build the request for job definition creation API
232233
@@ -270,6 +271,8 @@ def _build_create_job_definition_request(
270271
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
271272
object that configures network isolation, encryption of
272273
inter-container traffic, security group IDs, and subnets.
274+
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
275+
the monitoring schedule on the batch transform
273276
274277
Returns:
275278
dict: request parameters to create job definition.
@@ -366,6 +369,27 @@ def _build_create_job_definition_request(
366369
latest_baselining_job_config.probability_threshold_attribute
367370
)
368371
job_input = normalized_endpoint_input._to_request_dict()
372+
elif batch_transform_input is not None:
373+
# backfill attributes to batch transform input
374+
if latest_baselining_job_config is not None:
375+
if batch_transform_input.features_attribute is None:
376+
batch_transform_input.features_attribute = (
377+
latest_baselining_job_config.features_attribute
378+
)
379+
if batch_transform_input.inference_attribute is None:
380+
batch_transform_input.inference_attribute = (
381+
latest_baselining_job_config.inference_attribute
382+
)
383+
if batch_transform_input.probability_attribute is None:
384+
batch_transform_input.probability_attribute = (
385+
latest_baselining_job_config.probability_attribute
386+
)
387+
if batch_transform_input.probability_threshold_attribute is None:
388+
batch_transform_input.probability_threshold_attribute = (
389+
latest_baselining_job_config.probability_threshold_attribute
390+
)
391+
job_input = batch_transform_input._to_request_dict()
392+
369393
if ground_truth_input is not None:
370394
job_input["GroundTruthS3Input"] = dict(S3Uri=ground_truth_input)
371395

@@ -500,14 +524,15 @@ def suggest_baseline(
500524
# noinspection PyMethodOverriding
501525
def create_monitoring_schedule(
502526
self,
503-
endpoint_input,
504527
ground_truth_input,
528+
endpoint_input=None,
505529
analysis_config=None,
506530
output_s3_uri=None,
507531
constraints=None,
508532
monitor_schedule_name=None,
509533
schedule_cron_expression=None,
510534
enable_cloudwatch_metrics=True,
535+
batch_transform_input=None,
511536
):
512537
"""Creates a monitoring schedule.
513538
@@ -530,6 +555,8 @@ def create_monitoring_schedule(
530555
expressions. Default: Daily.
531556
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
532557
the baselining or monitoring jobs.
558+
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
559+
the monitoring schedule on the batch transform
533560
"""
534561
if self.job_definition_name is not None or self.monitoring_schedule_name is not None:
535562
message = (
@@ -540,6 +567,15 @@ def create_monitoring_schedule(
540567
_LOGGER.error(message)
541568
raise ValueError(message)
542569

570+
if (batch_transform_input is not None) ^ (endpoint_input is None):
571+
message = (
572+
"Need to have either batch_transform_input or endpoint_input to create an "
573+
"Amazon Model Monitoring Schedule. "
574+
"Please provide only one of the above required inputs"
575+
)
576+
_LOGGER.error(message)
577+
raise ValueError(message)
578+
543579
# create job definition
544580
monitor_schedule_name = self._generate_monitoring_schedule_name(
545581
schedule_name=monitor_schedule_name
@@ -569,6 +605,7 @@ def create_monitoring_schedule(
569605
env=self.env,
570606
tags=self.tags,
571607
network_config=self.network_config,
608+
batch_transform_input=batch_transform_input,
572609
)
573610
self.sagemaker_session.sagemaker_client.create_model_bias_job_definition(**request_dict)
574611

@@ -612,6 +649,7 @@ def update_monitoring_schedule(
612649
max_runtime_in_seconds=None,
613650
env=None,
614651
network_config=None,
652+
batch_transform_input=None,
615653
):
616654
"""Updates the existing monitoring schedule.
617655
@@ -651,6 +689,8 @@ def update_monitoring_schedule(
651689
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
652690
object that configures network isolation, encryption of
653691
inter-container traffic, security group IDs, and subnets.
692+
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
693+
the monitoring schedule on the batch transform
654694
"""
655695
valid_args = {
656696
arg: value for arg, value in locals().items() if arg != "self" and value is not None
@@ -660,6 +700,15 @@ def update_monitoring_schedule(
660700
if len(valid_args) <= 0:
661701
return
662702

703+
if batch_transform_input is not None and endpoint_input is not None:
704+
message = (
705+
"Need to have either batch_transform_input or endpoint_input to create an "
706+
"Amazon Model Monitoring Schedule. "
707+
"Please provide only one of the above required inputs"
708+
)
709+
_LOGGER.error(message)
710+
raise ValueError(message)
711+
663712
# Only need to update schedule expression
664713
if len(valid_args) == 1 and schedule_cron_expression is not None:
665714
self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
@@ -691,6 +740,7 @@ def update_monitoring_schedule(
691740
env=env,
692741
tags=self.tags,
693742
network_config=network_config,
743+
batch_transform_input=batch_transform_input,
694744
)
695745
self.sagemaker_session.sagemaker_client.create_model_bias_job_definition(**request_dict)
696746
try:
@@ -895,13 +945,14 @@ def suggest_baseline(
895945
# noinspection PyMethodOverriding
896946
def create_monitoring_schedule(
897947
self,
898-
endpoint_input,
948+
endpoint_input=None,
899949
analysis_config=None,
900950
output_s3_uri=None,
901951
constraints=None,
902952
monitor_schedule_name=None,
903953
schedule_cron_expression=None,
904954
enable_cloudwatch_metrics=True,
955+
batch_transform_input=None,
905956
):
906957
"""Creates a monitoring schedule.
907958
@@ -923,6 +974,8 @@ def create_monitoring_schedule(
923974
expressions. Default: Daily.
924975
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
925976
the baselining or monitoring jobs.
977+
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
978+
run the monitoring schedule on the batch transform
926979
"""
927980
if self.job_definition_name is not None or self.monitoring_schedule_name is not None:
928981
message = (
@@ -933,6 +986,15 @@ def create_monitoring_schedule(
933986
_LOGGER.error(message)
934987
raise ValueError(message)
935988

989+
if (batch_transform_input is not None) ^ (endpoint_input is None):
990+
message = (
991+
"Need to have either batch_transform_input or endpoint_input to create an "
992+
"Amazon Model Monitoring Schedule."
993+
"Please provide only one of the above required inputs"
994+
)
995+
_LOGGER.error(message)
996+
raise ValueError(message)
997+
936998
# create job definition
937999
monitor_schedule_name = self._generate_monitoring_schedule_name(
9381000
schedule_name=monitor_schedule_name
@@ -961,6 +1023,7 @@ def create_monitoring_schedule(
9611023
env=self.env,
9621024
tags=self.tags,
9631025
network_config=self.network_config,
1026+
batch_transform_input=batch_transform_input,
9641027
)
9651028
self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition(
9661029
**request_dict
@@ -1005,6 +1068,7 @@ def update_monitoring_schedule(
10051068
max_runtime_in_seconds=None,
10061069
env=None,
10071070
network_config=None,
1071+
batch_transform_input=None,
10081072
):
10091073
"""Updates the existing monitoring schedule.
10101074
@@ -1043,6 +1107,8 @@ def update_monitoring_schedule(
10431107
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
10441108
object that configures network isolation, encryption of
10451109
inter-container traffic, security group IDs, and subnets.
1110+
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
1111+
run the monitoring schedule on the batch transform
10461112
"""
10471113
valid_args = {
10481114
arg: value for arg, value in locals().items() if arg != "self" and value is not None
@@ -1052,6 +1118,15 @@ def update_monitoring_schedule(
10521118
if len(valid_args) <= 0:
10531119
raise ValueError("Nothing to update.")
10541120

1121+
if batch_transform_input is not None and endpoint_input is not None:
1122+
message = (
1123+
"Need to have either batch_transform_input or endpoint_input to create an "
1124+
"Amazon Model Monitoring Schedule. "
1125+
"Please provide only one of the above required inputs"
1126+
)
1127+
_LOGGER.error(message)
1128+
raise ValueError(message)
1129+
10551130
# Only need to update schedule expression
10561131
if len(valid_args) == 1 and schedule_cron_expression is not None:
10571132
self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
@@ -1084,6 +1159,7 @@ def update_monitoring_schedule(
10841159
env=env,
10851160
tags=self.tags,
10861161
network_config=network_config,
1162+
batch_transform_input=batch_transform_input,
10871163
)
10881164
self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition(
10891165
**request_dict

src/sagemaker/model_monitor/dataset_format.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,46 @@ def sagemaker_capture_json():
5959
6060
"""
6161
return {"sagemakerCaptureJson": {}}
62+
63+
64+
class MonitoringDatasetFormat(object):
    """Represents a dataset format that is used when calling a DefaultModelMonitor."""

    @staticmethod
    def csv(header=True):
        """Returns a DatasetFormat dict for a CSV dataset, for use with a DefaultModelMonitor.

        Args:
            header (bool): Whether the csv dataset to baseline and monitor has a header.
                Default: True.

        Returns:
            dict: DatasetFormat specification to be used by DefaultModelMonitor.
        """
        return {"Csv": {"Header": header}}

    @staticmethod
    def json(lines=True):
        """Returns a DatasetFormat dict for a JSON dataset, for use with a DefaultModelMonitor.

        Args:
            lines (bool): Whether the file should be read as a json object per line.
                Default: True.

        Returns:
            dict: DatasetFormat specification to be used by DefaultModelMonitor.
        """
        return {"Json": {"Line": lines}}

    @staticmethod
    def parquet():
        """Returns a DatasetFormat dict for a Parquet dataset, for use with a DefaultModelMonitor.

        Returns:
            dict: DatasetFormat specification to be used by DefaultModelMonitor.
        """
        return {"Parquet": {}}

0 commit comments

Comments
 (0)