@@ -227,6 +227,7 @@ def _build_create_job_definition_request(
227
227
env = None ,
228
228
tags = None ,
229
229
network_config = None ,
230
+ batch_transform_input = None ,
230
231
):
231
232
"""Build the request for job definition creation API
232
233
@@ -270,6 +271,8 @@ def _build_create_job_definition_request(
270
271
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
271
272
object that configures network isolation, encryption of
272
273
inter-container traffic, security group IDs, and subnets.
274
+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
275
+ the monitoring schedule on the batch transform
273
276
274
277
Returns:
275
278
dict: request parameters to create job definition.
@@ -366,6 +369,27 @@ def _build_create_job_definition_request(
366
369
latest_baselining_job_config .probability_threshold_attribute
367
370
)
368
371
job_input = normalized_endpoint_input ._to_request_dict ()
372
+ elif batch_transform_input is not None :
373
+ # backfill attributes to batch transform input
374
+ if latest_baselining_job_config is not None :
375
+ if batch_transform_input .features_attribute is None :
376
+ batch_transform_input .features_attribute = (
377
+ latest_baselining_job_config .features_attribute
378
+ )
379
+ if batch_transform_input .inference_attribute is None :
380
+ batch_transform_input .inference_attribute = (
381
+ latest_baselining_job_config .inference_attribute
382
+ )
383
+ if batch_transform_input .probability_attribute is None :
384
+ batch_transform_input .probability_attribute = (
385
+ latest_baselining_job_config .probability_attribute
386
+ )
387
+ if batch_transform_input .probability_threshold_attribute is None :
388
+ batch_transform_input .probability_threshold_attribute = (
389
+ latest_baselining_job_config .probability_threshold_attribute
390
+ )
391
+ job_input = batch_transform_input ._to_request_dict ()
392
+
369
393
if ground_truth_input is not None :
370
394
job_input ["GroundTruthS3Input" ] = dict (S3Uri = ground_truth_input )
371
395
@@ -500,37 +524,46 @@ def suggest_baseline(
500
524
# noinspection PyMethodOverriding
501
525
def create_monitoring_schedule (
502
526
self ,
503
- endpoint_input ,
504
- ground_truth_input ,
527
+ endpoint_input = None ,
528
+ ground_truth_input = None ,
505
529
analysis_config = None ,
506
530
output_s3_uri = None ,
507
531
constraints = None ,
508
532
monitor_schedule_name = None ,
509
533
schedule_cron_expression = None ,
510
534
enable_cloudwatch_metrics = True ,
535
+ batch_transform_input = None ,
511
536
):
512
537
"""Creates a monitoring schedule.
513
538
514
539
Args:
515
540
endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
516
- This can either be the endpoint name or an EndpointInput.
517
- ground_truth_input (str): S3 URI to ground truth dataset.
541
+ This can either be the endpoint name or an EndpointInput. (default: None)
542
+ ground_truth_input (str): S3 URI to ground truth dataset. (default: None)
518
543
analysis_config (str or BiasAnalysisConfig): URI to analysis_config for the bias job.
519
544
If it is None then configuration of the latest baselining job will be reused, but
520
- if no baselining job then fail the call.
545
+ if no baselining job then fail the call. (default: None)
521
546
output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
522
- Default: "s3://<default_session_bucket>/<job_name>/output"
547
+ Default: "s3://<default_session_bucket>/<job_name>/output" (default: None)
523
548
constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
524
549
for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
525
- to a constraints JSON file.
550
+ to a constraints JSON file. (default: None)
526
551
monitor_schedule_name (str): Schedule name. If not specified, the processor generates
527
552
a default job name, based on the image name and current timestamp.
553
+ (default: None)
528
554
schedule_cron_expression (str): The cron expression that dictates the frequency that
529
555
this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
530
- expressions. Default: Daily.
556
+ expressions. Default: Daily. (default: None)
531
557
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
532
- the baselining or monitoring jobs.
558
+ the baselining or monitoring jobs. (default: True)
559
+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
560
+ the monitoring schedule on the batch transform (default: None)
533
561
"""
562
+ # we default ground_truth_input to None in the function signature
563
+ # but verify they are giving here for positional argument
564
+ # backward compatibility reason.
565
+ if not ground_truth_input :
566
+ raise ValueError ("ground_truth_input can not be None." )
534
567
if self .job_definition_name is not None or self .monitoring_schedule_name is not None :
535
568
message = (
536
569
"It seems that this object was already used to create an Amazon Model "
@@ -540,6 +573,15 @@ def create_monitoring_schedule(
540
573
_LOGGER .error (message )
541
574
raise ValueError (message )
542
575
576
+ if (batch_transform_input is not None ) ^ (endpoint_input is None ):
577
+ message = (
578
+ "Need to have either batch_transform_input or endpoint_input to create an "
579
+ "Amazon Model Monitoring Schedule. "
580
+ "Please provide only one of the above required inputs"
581
+ )
582
+ _LOGGER .error (message )
583
+ raise ValueError (message )
584
+
543
585
# create job definition
544
586
monitor_schedule_name = self ._generate_monitoring_schedule_name (
545
587
schedule_name = monitor_schedule_name
@@ -569,6 +611,7 @@ def create_monitoring_schedule(
569
611
env = self .env ,
570
612
tags = self .tags ,
571
613
network_config = self .network_config ,
614
+ batch_transform_input = batch_transform_input ,
572
615
)
573
616
self .sagemaker_session .sagemaker_client .create_model_bias_job_definition (** request_dict )
574
617
@@ -612,6 +655,7 @@ def update_monitoring_schedule(
612
655
max_runtime_in_seconds = None ,
613
656
env = None ,
614
657
network_config = None ,
658
+ batch_transform_input = None ,
615
659
):
616
660
"""Updates the existing monitoring schedule.
617
661
@@ -651,6 +695,8 @@ def update_monitoring_schedule(
651
695
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
652
696
object that configures network isolation, encryption of
653
697
inter-container traffic, security group IDs, and subnets.
698
+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
699
+ the monitoring schedule on the batch transform
654
700
"""
655
701
valid_args = {
656
702
arg : value for arg , value in locals ().items () if arg != "self" and value is not None
@@ -660,6 +706,15 @@ def update_monitoring_schedule(
660
706
if len (valid_args ) <= 0 :
661
707
return
662
708
709
+ if batch_transform_input is not None and endpoint_input is not None :
710
+ message = (
711
+ "Need to have either batch_transform_input or endpoint_input to create an "
712
+ "Amazon Model Monitoring Schedule. "
713
+ "Please provide only one of the above required inputs"
714
+ )
715
+ _LOGGER .error (message )
716
+ raise ValueError (message )
717
+
663
718
# Only need to update schedule expression
664
719
if len (valid_args ) == 1 and schedule_cron_expression is not None :
665
720
self ._update_monitoring_schedule (self .job_definition_name , schedule_cron_expression )
@@ -691,6 +746,7 @@ def update_monitoring_schedule(
691
746
env = env ,
692
747
tags = self .tags ,
693
748
network_config = network_config ,
749
+ batch_transform_input = batch_transform_input ,
694
750
)
695
751
self .sagemaker_session .sagemaker_client .create_model_bias_job_definition (** request_dict )
696
752
try :
@@ -895,19 +951,20 @@ def suggest_baseline(
895
951
# noinspection PyMethodOverriding
896
952
def create_monitoring_schedule (
897
953
self ,
898
- endpoint_input ,
954
+ endpoint_input = None ,
899
955
analysis_config = None ,
900
956
output_s3_uri = None ,
901
957
constraints = None ,
902
958
monitor_schedule_name = None ,
903
959
schedule_cron_expression = None ,
904
960
enable_cloudwatch_metrics = True ,
961
+ batch_transform_input = None ,
905
962
):
906
963
"""Creates a monitoring schedule.
907
964
908
965
Args:
909
966
endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
910
- This can either be the endpoint name or an EndpointInput.
967
+ This can either be the endpoint name or an EndpointInput. (default: None)
911
968
analysis_config (str or ExplainabilityAnalysisConfig): URI to the analysis_config for
912
969
the explainability job. If it is None then configuration of the latest baselining
913
970
job will be reused, but if no baselining job then fail the call.
@@ -923,6 +980,8 @@ def create_monitoring_schedule(
923
980
expressions. Default: Daily.
924
981
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
925
982
the baselining or monitoring jobs.
983
+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
984
+ run the monitoring schedule on the batch transform
926
985
"""
927
986
if self .job_definition_name is not None or self .monitoring_schedule_name is not None :
928
987
message = (
@@ -933,6 +992,15 @@ def create_monitoring_schedule(
933
992
_LOGGER .error (message )
934
993
raise ValueError (message )
935
994
995
+ if (batch_transform_input is not None ) ^ (endpoint_input is None ):
996
+ message = (
997
+ "Need to have either batch_transform_input or endpoint_input to create an "
998
+ "Amazon Model Monitoring Schedule."
999
+ "Please provide only one of the above required inputs"
1000
+ )
1001
+ _LOGGER .error (message )
1002
+ raise ValueError (message )
1003
+
936
1004
# create job definition
937
1005
monitor_schedule_name = self ._generate_monitoring_schedule_name (
938
1006
schedule_name = monitor_schedule_name
@@ -961,6 +1029,7 @@ def create_monitoring_schedule(
961
1029
env = self .env ,
962
1030
tags = self .tags ,
963
1031
network_config = self .network_config ,
1032
+ batch_transform_input = batch_transform_input ,
964
1033
)
965
1034
self .sagemaker_session .sagemaker_client .create_model_explainability_job_definition (
966
1035
** request_dict
@@ -1005,6 +1074,7 @@ def update_monitoring_schedule(
1005
1074
max_runtime_in_seconds = None ,
1006
1075
env = None ,
1007
1076
network_config = None ,
1077
+ batch_transform_input = None ,
1008
1078
):
1009
1079
"""Updates the existing monitoring schedule.
1010
1080
@@ -1043,6 +1113,8 @@ def update_monitoring_schedule(
1043
1113
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
1044
1114
object that configures network isolation, encryption of
1045
1115
inter-container traffic, security group IDs, and subnets.
1116
+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
1117
+ run the monitoring schedule on the batch transform
1046
1118
"""
1047
1119
valid_args = {
1048
1120
arg : value for arg , value in locals ().items () if arg != "self" and value is not None
@@ -1052,6 +1124,15 @@ def update_monitoring_schedule(
1052
1124
if len (valid_args ) <= 0 :
1053
1125
raise ValueError ("Nothing to update." )
1054
1126
1127
+ if batch_transform_input is not None and endpoint_input is not None :
1128
+ message = (
1129
+ "Need to have either batch_transform_input or endpoint_input to create an "
1130
+ "Amazon Model Monitoring Schedule. "
1131
+ "Please provide only one of the above required inputs"
1132
+ )
1133
+ _LOGGER .error (message )
1134
+ raise ValueError (message )
1135
+
1055
1136
# Only need to update schedule expression
1056
1137
if len (valid_args ) == 1 and schedule_cron_expression is not None :
1057
1138
self ._update_monitoring_schedule (self .job_definition_name , schedule_cron_expression )
@@ -1084,6 +1165,7 @@ def update_monitoring_schedule(
1084
1165
env = env ,
1085
1166
tags = self .tags ,
1086
1167
network_config = network_config ,
1168
+ batch_transform_input = batch_transform_input ,
1087
1169
)
1088
1170
self .sagemaker_session .sagemaker_client .create_model_explainability_job_definition (
1089
1171
** request_dict
0 commit comments