
Commit 7669263

Keshav Chandak and jerrypeng7773 authored
feature: added support for batch transform with model monitoring (#3418)
Co-authored-by: Keshav Chandak <[email protected]>
Co-authored-by: jerrypeng7773 <[email protected]>
1 parent b267f1e commit 7669263

14 files changed · +1267 −30 lines

src/sagemaker/inputs.py

+43
@@ -176,6 +176,7 @@ class TransformInput(object):
     output_filter: str = attr.ib(default=None)
     join_source: str = attr.ib(default=None)
     model_client_config: dict = attr.ib(default=None)
+    batch_data_capture_config: dict = attr.ib(default=None)


 class FileSystemInput(object):
@@ -232,3 +233,45 @@ def __init__(

         if content_type:
             self.config["ContentType"] = content_type
+
+
+class BatchDataCaptureConfig(object):
+    """Configuration object passed in when creating a batch transform job.
+
+    Specifies configuration related to batch transform job data capture for use with
+    Amazon SageMaker Model Monitoring.
+    """
+
+    def __init__(
+        self,
+        destination_s3_uri: str,
+        kms_key_id: str = None,
+        generate_inference_id: bool = None,
+    ):
+        """Create a new BatchDataCaptureConfig.
+
+        Args:
+            destination_s3_uri (str): S3 location to store the captured data.
+            kms_key_id (str): The KMS key to use when writing to S3. KmsKeyId can be the
+                ID of a KMS key, the ARN of a KMS key, the alias of a KMS key, or the
+                alias of a KMS key alias. The KmsKeyId is applied to all outputs.
+                (default: None)
+            generate_inference_id (bool): Flag to generate an inference id.
+                (default: None)
+        """
+        self.destination_s3_uri = destination_s3_uri
+        self.kms_key_id = kms_key_id
+        self.generate_inference_id = generate_inference_id
+
+    def _to_request_dict(self):
+        """Generates a request dictionary using the parameters provided to the class."""
+        batch_data_capture_config = {
+            "DestinationS3Uri": self.destination_s3_uri,
+        }
+
+        if self.kms_key_id is not None:
+            batch_data_capture_config["KmsKeyId"] = self.kms_key_id
+        if self.generate_inference_id is not None:
+            batch_data_capture_config["GenerateInferenceId"] = self.generate_inference_id
+
+        return batch_data_capture_config
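
The new configuration object can be exercised on its own. A minimal sketch of what it produces (the bucket and prefix are placeholders, and _to_request_dict is the internal helper shown in the diff, not public API):

from sagemaker.inputs import BatchDataCaptureConfig

# Hypothetical destination for data captured from a batch transform job
capture_config = BatchDataCaptureConfig(
    destination_s3_uri="s3://my-bucket/transform-data-capture",
    generate_inference_id=True,
)

# Request payload attached to the transform job
print(capture_config._to_request_dict())
# {'DestinationS3Uri': 's3://my-bucket/transform-data-capture', 'GenerateInferenceId': True}

This is also the payload the new batch_data_capture_config field on TransformInput is meant to carry (typed as a plain dict in this diff).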

src/sagemaker/model_monitor/__init__.py

+2
@@ -23,6 +23,7 @@
 from sagemaker.model_monitor.model_monitoring import BaseliningJob  # noqa: F401
 from sagemaker.model_monitor.model_monitoring import MonitoringExecution  # noqa: F401
 from sagemaker.model_monitor.model_monitoring import EndpointInput  # noqa: F401
+from sagemaker.model_monitor.model_monitoring import BatchTransformInput  # noqa: F401
 from sagemaker.model_monitor.model_monitoring import MonitoringOutput  # noqa: F401
 from sagemaker.model_monitor.model_monitoring import ModelQualityMonitor  # noqa: F401

@@ -42,5 +43,6 @@

 from sagemaker.model_monitor.data_capture_config import DataCaptureConfig  # noqa: F401
 from sagemaker.model_monitor.dataset_format import DatasetFormat  # noqa: F401
+from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat  # noqa: F401

 from sagemaker.network import NetworkConfig  # noqa: F401
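
With these re-exports in place, both new helpers resolve from the package root rather than the private submodules; a quick sketch:

from sagemaker.model_monitor import BatchTransformInput, MonitoringDatasetFormat

# BatchTransformInput describes batch transform output to monitor;
# MonitoringDatasetFormat describes how that captured data is formatted.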

src/sagemaker/model_monitor/clarify_model_monitoring.py

+93 −11
@@ -227,6 +227,7 @@ def _build_create_job_definition_request(
         env=None,
         tags=None,
         network_config=None,
+        batch_transform_input=None,
     ):
         """Build the request for job definition creation API

@@ -270,6 +271,8 @@ def _build_create_job_definition_request(
             network_config (sagemaker.network.NetworkConfig): A NetworkConfig
                 object that configures network isolation, encryption of
                 inter-container traffic, security group IDs, and subnets.
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
+                the monitoring schedule on the batch transform.

         Returns:
             dict: request parameters to create job definition.
@@ -366,6 +369,27 @@ def _build_create_job_definition_request(
                     latest_baselining_job_config.probability_threshold_attribute
                 )
             job_input = normalized_endpoint_input._to_request_dict()
+        elif batch_transform_input is not None:
+            # backfill attributes to batch transform input
+            if latest_baselining_job_config is not None:
+                if batch_transform_input.features_attribute is None:
+                    batch_transform_input.features_attribute = (
+                        latest_baselining_job_config.features_attribute
+                    )
+                if batch_transform_input.inference_attribute is None:
+                    batch_transform_input.inference_attribute = (
+                        latest_baselining_job_config.inference_attribute
+                    )
+                if batch_transform_input.probability_attribute is None:
+                    batch_transform_input.probability_attribute = (
+                        latest_baselining_job_config.probability_attribute
+                    )
+                if batch_transform_input.probability_threshold_attribute is None:
+                    batch_transform_input.probability_threshold_attribute = (
+                        latest_baselining_job_config.probability_threshold_attribute
+                    )
+            job_input = batch_transform_input._to_request_dict()
+
         if ground_truth_input is not None:
             job_input["GroundTruthS3Input"] = dict(S3Uri=ground_truth_input)

@@ -500,37 +524,46 @@ def suggest_baseline(
     # noinspection PyMethodOverriding
     def create_monitoring_schedule(
         self,
-        endpoint_input,
-        ground_truth_input,
+        endpoint_input=None,
+        ground_truth_input=None,
         analysis_config=None,
         output_s3_uri=None,
         constraints=None,
         monitor_schedule_name=None,
         schedule_cron_expression=None,
         enable_cloudwatch_metrics=True,
+        batch_transform_input=None,
     ):
         """Creates a monitoring schedule.

         Args:
             endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
-                This can either be the endpoint name or an EndpointInput.
-            ground_truth_input (str): S3 URI to ground truth dataset.
+                This can either be the endpoint name or an EndpointInput. (default: None)
+            ground_truth_input (str): S3 URI to ground truth dataset. (default: None)
             analysis_config (str or BiasAnalysisConfig): URI to analysis_config for the bias job.
                 If it is None then configuration of the latest baselining job will be reused, but
-                if no baselining job then fail the call.
+                if no baselining job then fail the call. (default: None)
             output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
-                Default: "s3://<default_session_bucket>/<job_name>/output"
+                Default: "s3://<default_session_bucket>/<job_name>/output" (default: None)
             constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
                 for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
-                to a constraints JSON file.
+                to a constraints JSON file. (default: None)
             monitor_schedule_name (str): Schedule name. If not specified, the processor generates
                 a default job name, based on the image name and current timestamp.
+                (default: None)
             schedule_cron_expression (str): The cron expression that dictates the frequency that
                 this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
-                expressions. Default: Daily.
+                expressions. Default: Daily. (default: None)
             enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
-                the baselining or monitoring jobs.
+                the baselining or monitoring jobs. (default: True)
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
+                the monitoring schedule on the batch transform. (default: None)
         """
+        # ground_truth_input defaults to None in the signature only to keep the
+        # positional-argument order backward compatible; verify it was provided.
+        if not ground_truth_input:
+            raise ValueError("ground_truth_input can not be None.")
         if self.job_definition_name is not None or self.monitoring_schedule_name is not None:
             message = (
                 "It seems that this object was already used to create an Amazon Model "
@@ -540,6 +573,15 @@ def create_monitoring_schedule(
             _LOGGER.error(message)
             raise ValueError(message)

+        if (batch_transform_input is not None) ^ (endpoint_input is None):
+            message = (
+                "Need to have either batch_transform_input or endpoint_input to create an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide only one of the above required inputs"
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         # create job definition
         monitor_schedule_name = self._generate_monitoring_schedule_name(
             schedule_name=monitor_schedule_name
@@ -569,6 +611,7 @@ def create_monitoring_schedule(
             env=self.env,
             tags=self.tags,
             network_config=self.network_config,
+            batch_transform_input=batch_transform_input,
         )
         self.sagemaker_session.sagemaker_client.create_model_bias_job_definition(**request_dict)

@@ -612,6 +655,7 @@ def update_monitoring_schedule(
         max_runtime_in_seconds=None,
         env=None,
         network_config=None,
+        batch_transform_input=None,
     ):
         """Updates the existing monitoring schedule.

@@ -651,6 +695,8 @@ def update_monitoring_schedule(
             network_config (sagemaker.network.NetworkConfig): A NetworkConfig
                 object that configures network isolation, encryption of
                 inter-container traffic, security group IDs, and subnets.
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
+                the monitoring schedule on the batch transform.
         """
         valid_args = {
             arg: value for arg, value in locals().items() if arg != "self" and value is not None
@@ -660,6 +706,15 @@ def update_monitoring_schedule(
         if len(valid_args) <= 0:
             return

+        if batch_transform_input is not None and endpoint_input is not None:
+            message = (
+                "Need to have either batch_transform_input or endpoint_input to create an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide only one of the above required inputs"
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         # Only need to update schedule expression
         if len(valid_args) == 1 and schedule_cron_expression is not None:
             self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
@@ -691,6 +746,7 @@ def update_monitoring_schedule(
             env=env,
             tags=self.tags,
             network_config=network_config,
+            batch_transform_input=batch_transform_input,
         )
         self.sagemaker_session.sagemaker_client.create_model_bias_job_definition(**request_dict)
         try:
@@ -895,19 +951,20 @@ def suggest_baseline(
     # noinspection PyMethodOverriding
     def create_monitoring_schedule(
         self,
-        endpoint_input,
+        endpoint_input=None,
         analysis_config=None,
         output_s3_uri=None,
         constraints=None,
         monitor_schedule_name=None,
         schedule_cron_expression=None,
         enable_cloudwatch_metrics=True,
+        batch_transform_input=None,
     ):
         """Creates a monitoring schedule.

         Args:
             endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
-                This can either be the endpoint name or an EndpointInput.
+                This can either be the endpoint name or an EndpointInput. (default: None)
             analysis_config (str or ExplainabilityAnalysisConfig): URI to the analysis_config for
                 the explainability job. If it is None then configuration of the latest baselining
                 job will be reused, but if no baselining job then fail the call.
@@ -923,6 +980,8 @@ def create_monitoring_schedule(
                 expressions. Default: Daily.
             enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
                 the baselining or monitoring jobs.
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
+                run the monitoring schedule on the batch transform.
         """
         if self.job_definition_name is not None or self.monitoring_schedule_name is not None:
             message = (
@@ -933,6 +992,15 @@ def create_monitoring_schedule(
             _LOGGER.error(message)
             raise ValueError(message)

+        if (batch_transform_input is not None) ^ (endpoint_input is None):
+            message = (
+                "Need to have either batch_transform_input or endpoint_input to create an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide only one of the above required inputs"
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         # create job definition
         monitor_schedule_name = self._generate_monitoring_schedule_name(
             schedule_name=monitor_schedule_name
@@ -961,6 +1029,7 @@ def create_monitoring_schedule(
             env=self.env,
             tags=self.tags,
             network_config=self.network_config,
+            batch_transform_input=batch_transform_input,
         )
         self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition(
             **request_dict
@@ -1005,6 +1074,7 @@ def update_monitoring_schedule(
         max_runtime_in_seconds=None,
         env=None,
         network_config=None,
+        batch_transform_input=None,
     ):
         """Updates the existing monitoring schedule.

@@ -1043,6 +1113,8 @@ def update_monitoring_schedule(
             network_config (sagemaker.network.NetworkConfig): A NetworkConfig
                 object that configures network isolation, encryption of
                 inter-container traffic, security group IDs, and subnets.
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
+                run the monitoring schedule on the batch transform.
         """
         valid_args = {
             arg: value for arg, value in locals().items() if arg != "self" and value is not None
@@ -1052,6 +1124,15 @@ def update_monitoring_schedule(
         if len(valid_args) <= 0:
             raise ValueError("Nothing to update.")

+        if batch_transform_input is not None and endpoint_input is not None:
+            message = (
+                "Need to have either batch_transform_input or endpoint_input to create an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide only one of the above required inputs"
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         # Only need to update schedule expression
         if len(valid_args) == 1 and schedule_cron_expression is not None:
             self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
@@ -1084,6 +1165,7 @@ def update_monitoring_schedule(
             env=env,
             tags=self.tags,
             network_config=network_config,
+            batch_transform_input=batch_transform_input,
        )
        self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition(
            **request_dict
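
Taken together, these changes let the Clarify monitors point a schedule at batch transform output instead of a live endpoint. A minimal sketch under stated assumptions: the BatchTransformInput constructor arguments are defined in model_monitoring.py (outside this excerpt) and are assumed here, and the role ARN, S3 URIs, and analysis config path are placeholders:

from sagemaker.model_monitor import (
    BatchTransformInput,
    CronExpressionGenerator,
    ModelBiasMonitor,
    MonitoringDatasetFormat,
)

monitor = ModelBiasMonitor(
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role ARN
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# Assumed constructor arguments: where the batch transform job captured its
# data, where to stage it inside the monitoring container, and its format.
transform_input = BatchTransformInput(
    data_captured_destination_s3_uri="s3://my-bucket/transform-data-capture",
    destination="/opt/ml/processing/input",
    dataset_format=MonitoringDatasetFormat.csv(header=True),
)

# Exactly one of endpoint_input / batch_transform_input may be supplied;
# supplying both, or neither, raises the ValueError added in this commit.
monitor.create_monitoring_schedule(
    batch_transform_input=transform_input,
    ground_truth_input="s3://my-bucket/ground-truth",
    analysis_config="s3://my-bucket/bias/analysis_config.json",  # placeholder; omit to reuse the latest baselining job
    schedule_cron_expression=CronExpressionGenerator.hourly(),
)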

src/sagemaker/model_monitor/dataset_format.py

+43
@@ -59,3 +59,46 @@ def sagemaker_capture_json():

         """
         return {"sagemakerCaptureJson": {}}
+
+
+class MonitoringDatasetFormat(object):
+    """Represents a Dataset Format that is used when calling a DefaultModelMonitor."""
+
+    @staticmethod
+    def csv(header=True):
+        """Returns a DatasetFormat JSON string for use with a DefaultModelMonitor.
+
+        Args:
+            header (bool): Whether the csv dataset to baseline and monitor has a header.
+                Default: True.
+
+        Returns:
+            dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor.
+
+        """
+        return {"Csv": {"Header": header}}
+
+    @staticmethod
+    def json(lines=True):
+        """Returns a DatasetFormat JSON string for use with a DefaultModelMonitor.
+
+        Args:
+            lines (bool): Whether the file should be read as a json object per line. Default: True.
+
+        Returns:
+            dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor.
+
+        """
+        return {"Json": {"Line": lines}}
+
+    @staticmethod
+    def parquet():
+        """Returns a DatasetFormat JSON string for use with a DefaultModelMonitor.
+
+        Returns:
+            dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor.
+
+        """
+        return {"Parquet": {}}
