@@ -227,6 +227,7 @@ def _build_create_job_definition_request(
227
227
env = None ,
228
228
tags = None ,
229
229
network_config = None ,
230
+ batch_transform_input = None ,
230
231
):
231
232
"""Build the request for job definition creation API
232
233
@@ -270,6 +271,8 @@ def _build_create_job_definition_request(
270
271
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
271
272
object that configures network isolation, encryption of
272
273
inter-container traffic, security group IDs, and subnets.
274
+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
275
+ the monitoring schedule on the batch transform
273
276
274
277
Returns:
275
278
dict: request parameters to create job definition.
@@ -366,6 +369,27 @@ def _build_create_job_definition_request(
366
369
latest_baselining_job_config .probability_threshold_attribute
367
370
)
368
371
job_input = normalized_endpoint_input ._to_request_dict ()
372
+ elif batch_transform_input is not None :
373
+ # backfill attributes to batch transform input
374
+ if latest_baselining_job_config is not None :
375
+ if batch_transform_input .features_attribute is None :
376
+ batch_transform_input .features_attribute = (
377
+ latest_baselining_job_config .features_attribute
378
+ )
379
+ if batch_transform_input .inference_attribute is None :
380
+ batch_transform_input .inference_attribute = (
381
+ latest_baselining_job_config .inference_attribute
382
+ )
383
+ if batch_transform_input .probability_attribute is None :
384
+ batch_transform_input .probability_attribute = (
385
+ latest_baselining_job_config .probability_attribute
386
+ )
387
+ if batch_transform_input .probability_threshold_attribute is None :
388
+ batch_transform_input .probability_threshold_attribute = (
389
+ latest_baselining_job_config .probability_threshold_attribute
390
+ )
391
+ job_input = batch_transform_input ._to_request_dict ()
392
+
369
393
if ground_truth_input is not None :
370
394
job_input ["GroundTruthS3Input" ] = dict (S3Uri = ground_truth_input )
371
395
@@ -500,14 +524,15 @@ def suggest_baseline(
500
524
# noinspection PyMethodOverriding
501
525
def create_monitoring_schedule (
502
526
self ,
503
- endpoint_input ,
504
527
ground_truth_input ,
528
+ endpoint_input = None ,
505
529
analysis_config = None ,
506
530
output_s3_uri = None ,
507
531
constraints = None ,
508
532
monitor_schedule_name = None ,
509
533
schedule_cron_expression = None ,
510
534
enable_cloudwatch_metrics = True ,
535
+ batch_transform_input = None ,
511
536
):
512
537
"""Creates a monitoring schedule.
513
538
@@ -530,6 +555,8 @@ def create_monitoring_schedule(
530
555
expressions. Default: Daily.
531
556
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
532
557
the baselining or monitoring jobs.
558
+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
559
+ the monitoring schedule on the batch transform
533
560
"""
534
561
if self .job_definition_name is not None or self .monitoring_schedule_name is not None :
535
562
message = (
@@ -540,6 +567,15 @@ def create_monitoring_schedule(
540
567
_LOGGER .error (message )
541
568
raise ValueError (message )
542
569
570
+ if (batch_transform_input is not None ) ^ (endpoint_input is None ):
571
+ message = (
572
+ "Need to have either batch_transform_input or endpoint_input to create an "
573
+ "Amazon Model Monitoring Schedule. "
574
+ "Please provide only one of the above required inputs"
575
+ )
576
+ _LOGGER .error (message )
577
+ raise ValueError (message )
578
+
543
579
# create job definition
544
580
monitor_schedule_name = self ._generate_monitoring_schedule_name (
545
581
schedule_name = monitor_schedule_name
@@ -569,6 +605,7 @@ def create_monitoring_schedule(
569
605
env = self .env ,
570
606
tags = self .tags ,
571
607
network_config = self .network_config ,
608
+ batch_transform_input = batch_transform_input ,
572
609
)
573
610
self .sagemaker_session .sagemaker_client .create_model_bias_job_definition (** request_dict )
574
611
@@ -612,6 +649,7 @@ def update_monitoring_schedule(
612
649
max_runtime_in_seconds = None ,
613
650
env = None ,
614
651
network_config = None ,
652
+ batch_transform_input = None ,
615
653
):
616
654
"""Updates the existing monitoring schedule.
617
655
@@ -651,6 +689,8 @@ def update_monitoring_schedule(
651
689
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
652
690
object that configures network isolation, encryption of
653
691
inter-container traffic, security group IDs, and subnets.
692
+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
693
+ the monitoring schedule on the batch transform
654
694
"""
655
695
valid_args = {
656
696
arg : value for arg , value in locals ().items () if arg != "self" and value is not None
@@ -660,6 +700,15 @@ def update_monitoring_schedule(
660
700
if len (valid_args ) <= 0 :
661
701
return
662
702
703
+ if batch_transform_input is not None and endpoint_input is not None :
704
+ message = (
705
+ "Need to have either batch_transform_input or endpoint_input to create an "
706
+ "Amazon Model Monitoring Schedule. "
707
+ "Please provide only one of the above required inputs"
708
+ )
709
+ _LOGGER .error (message )
710
+ raise ValueError (message )
711
+
663
712
# Only need to update schedule expression
664
713
if len (valid_args ) == 1 and schedule_cron_expression is not None :
665
714
self ._update_monitoring_schedule (self .job_definition_name , schedule_cron_expression )
@@ -691,6 +740,7 @@ def update_monitoring_schedule(
691
740
env = env ,
692
741
tags = self .tags ,
693
742
network_config = network_config ,
743
+ batch_transform_input = batch_transform_input ,
694
744
)
695
745
self .sagemaker_session .sagemaker_client .create_model_bias_job_definition (** request_dict )
696
746
try :
@@ -895,13 +945,14 @@ def suggest_baseline(
895
945
# noinspection PyMethodOverriding
896
946
def create_monitoring_schedule (
897
947
self ,
898
- endpoint_input ,
948
+ endpoint_input = None ,
899
949
analysis_config = None ,
900
950
output_s3_uri = None ,
901
951
constraints = None ,
902
952
monitor_schedule_name = None ,
903
953
schedule_cron_expression = None ,
904
954
enable_cloudwatch_metrics = True ,
955
+ batch_transform_input = None ,
905
956
):
906
957
"""Creates a monitoring schedule.
907
958
@@ -923,6 +974,8 @@ def create_monitoring_schedule(
923
974
expressions. Default: Daily.
924
975
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
925
976
the baselining or monitoring jobs.
977
+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
978
+ run the monitoring schedule on the batch transform
926
979
"""
927
980
if self .job_definition_name is not None or self .monitoring_schedule_name is not None :
928
981
message = (
@@ -933,6 +986,15 @@ def create_monitoring_schedule(
933
986
_LOGGER .error (message )
934
987
raise ValueError (message )
935
988
989
+ if (batch_transform_input is not None ) ^ (endpoint_input is None ):
990
+ message = (
991
+ "Need to have either batch_transform_input or endpoint_input to create an "
992
+ "Amazon Model Monitoring Schedule."
993
+ "Please provide only one of the above required inputs"
994
+ )
995
+ _LOGGER .error (message )
996
+ raise ValueError (message )
997
+
936
998
# create job definition
937
999
monitor_schedule_name = self ._generate_monitoring_schedule_name (
938
1000
schedule_name = monitor_schedule_name
@@ -961,6 +1023,7 @@ def create_monitoring_schedule(
961
1023
env = self .env ,
962
1024
tags = self .tags ,
963
1025
network_config = self .network_config ,
1026
+ batch_transform_input = batch_transform_input ,
964
1027
)
965
1028
self .sagemaker_session .sagemaker_client .create_model_explainability_job_definition (
966
1029
** request_dict
@@ -1005,6 +1068,7 @@ def update_monitoring_schedule(
1005
1068
max_runtime_in_seconds = None ,
1006
1069
env = None ,
1007
1070
network_config = None ,
1071
+ batch_transform_input = None ,
1008
1072
):
1009
1073
"""Updates the existing monitoring schedule.
1010
1074
@@ -1043,6 +1107,8 @@ def update_monitoring_schedule(
1043
1107
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
1044
1108
object that configures network isolation, encryption of
1045
1109
inter-container traffic, security group IDs, and subnets.
1110
+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
1111
+ run the monitoring schedule on the batch transform
1046
1112
"""
1047
1113
valid_args = {
1048
1114
arg : value for arg , value in locals ().items () if arg != "self" and value is not None
@@ -1052,6 +1118,15 @@ def update_monitoring_schedule(
1052
1118
if len (valid_args ) <= 0 :
1053
1119
raise ValueError ("Nothing to update." )
1054
1120
1121
+ if batch_transform_input is not None and endpoint_input is not None :
1122
+ message = (
1123
+ "Need to have either batch_transform_input or endpoint_input to create an "
1124
+ "Amazon Model Monitoring Schedule. "
1125
+ "Please provide only one of the above required inputs"
1126
+ )
1127
+ _LOGGER .error (message )
1128
+ raise ValueError (message )
1129
+
1055
1130
# Only need to update schedule expression
1056
1131
if len (valid_args ) == 1 and schedule_cron_expression is not None :
1057
1132
self ._update_monitoring_schedule (self .job_definition_name , schedule_cron_expression )
@@ -1084,6 +1159,7 @@ def update_monitoring_schedule(
1084
1159
env = env ,
1085
1160
tags = self .tags ,
1086
1161
network_config = network_config ,
1162
+ batch_transform_input = batch_transform_input ,
1087
1163
)
1088
1164
self .sagemaker_session .sagemaker_client .create_model_explainability_job_definition (
1089
1165
** request_dict
0 commit comments