diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 59c7166792..f0c678c623 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -176,6 +176,7 @@ class TransformInput(object): output_filter: str = attr.ib(default=None) join_source: str = attr.ib(default=None) model_client_config: dict = attr.ib(default=None) + batch_data_capture_config: dict = attr.ib(default=None) class FileSystemInput(object): @@ -232,3 +233,45 @@ def __init__( if content_type: self.config["ContentType"] = content_type + + +class BatchDataCaptureConfig(object): + """Configuration object passed in when create a batch transform job. + + Specifies configuration related to batch transform job data capture for use with + Amazon SageMaker Model Monitoring + """ + + def __init__( + self, + destination_s3_uri: str, + kms_key_id: str = None, + generate_inference_id: bool = None, + ): + """Create new BatchDataCaptureConfig + + Args: + destination_s3_uri (str): S3 Location to store the captured data + kms_key_id (str): The KMS key to use when writing to S3. + KmsKeyId can be an ID of a KMS key, ARN of a KMS key, alias of a KMS key, + or alias of a KMS key. The KmsKeyId is applied to all outputs. + (default: None) + generate_inference_id (bool): Flag to generate an inference id + (default: None) + """ + self.destination_s3_uri = destination_s3_uri + self.kms_key_id = kms_key_id + self.generate_inference_id = generate_inference_id + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + batch_data_capture_config = { + "DestinationS3Uri": self.destination_s3_uri, + } + + if self.kms_key_id is not None: + batch_data_capture_config["KmsKeyId"] = self.kms_key_id + if self.generate_inference_id is not None: + batch_data_capture_config["GenerateInferenceId"] = self.generate_inference_id + + return batch_data_capture_config diff --git a/src/sagemaker/model_monitor/__init__.py b/src/sagemaker/model_monitor/__init__.py index 4adbea5b09..2768564be7 100644 --- a/src/sagemaker/model_monitor/__init__.py +++ b/src/sagemaker/model_monitor/__init__.py @@ -23,6 +23,7 @@ from sagemaker.model_monitor.model_monitoring import BaseliningJob # noqa: F401 from sagemaker.model_monitor.model_monitoring import MonitoringExecution # noqa: F401 from sagemaker.model_monitor.model_monitoring import EndpointInput # noqa: F401 +from sagemaker.model_monitor.model_monitoring import BatchTransformInput # noqa: F401 from sagemaker.model_monitor.model_monitoring import MonitoringOutput # noqa: F401 from sagemaker.model_monitor.model_monitoring import ModelQualityMonitor # noqa: F401 @@ -42,5 +43,6 @@ from sagemaker.model_monitor.data_capture_config import DataCaptureConfig # noqa: F401 from sagemaker.model_monitor.dataset_format import DatasetFormat # noqa: F401 +from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat # noqa: F401 from sagemaker.network import NetworkConfig # noqa: F401 diff --git a/src/sagemaker/model_monitor/clarify_model_monitoring.py b/src/sagemaker/model_monitor/clarify_model_monitoring.py index 477c48e150..1a788a0d53 100644 --- a/src/sagemaker/model_monitor/clarify_model_monitoring.py +++ b/src/sagemaker/model_monitor/clarify_model_monitoring.py @@ -227,6 +227,7 @@ def _build_create_job_definition_request( env=None, tags=None, network_config=None, + batch_transform_input=None, ): """Build the request for job definition creation API @@ -270,6 +271,8 @@ def _build_create_job_definition_request( network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run + the monitoring schedule on the batch transform Returns: dict: request parameters to create job definition. @@ -366,6 +369,27 @@ def _build_create_job_definition_request( latest_baselining_job_config.probability_threshold_attribute ) job_input = normalized_endpoint_input._to_request_dict() + elif batch_transform_input is not None: + # backfill attributes to batch transform input + if latest_baselining_job_config is not None: + if batch_transform_input.features_attribute is None: + batch_transform_input.features_attribute = ( + latest_baselining_job_config.features_attribute + ) + if batch_transform_input.inference_attribute is None: + batch_transform_input.inference_attribute = ( + latest_baselining_job_config.inference_attribute + ) + if batch_transform_input.probability_attribute is None: + batch_transform_input.probability_attribute = ( + latest_baselining_job_config.probability_attribute + ) + if batch_transform_input.probability_threshold_attribute is None: + batch_transform_input.probability_threshold_attribute = ( + latest_baselining_job_config.probability_threshold_attribute + ) + job_input = batch_transform_input._to_request_dict() + if ground_truth_input is not None: job_input["GroundTruthS3Input"] = dict(S3Uri=ground_truth_input) @@ -500,37 +524,46 @@ def suggest_baseline( # noinspection PyMethodOverriding def create_monitoring_schedule( self, - endpoint_input, - ground_truth_input, + endpoint_input=None, + ground_truth_input=None, analysis_config=None, output_s3_uri=None, constraints=None, monitor_schedule_name=None, schedule_cron_expression=None, enable_cloudwatch_metrics=True, + batch_transform_input=None, ): """Creates a monitoring schedule. Args: endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor. - This can either be the endpoint name or an EndpointInput. - ground_truth_input (str): S3 URI to ground truth dataset. + This can either be the endpoint name or an EndpointInput. (default: None) + ground_truth_input (str): S3 URI to ground truth dataset. (default: None) analysis_config (str or BiasAnalysisConfig): URI to analysis_config for the bias job. If it is None then configuration of the latest baselining job will be reused, but - if no baselining job then fail the call. + if no baselining job then fail the call. (default: None) output_s3_uri (str): S3 destination of the constraint_violations and analysis result. - Default: "s3:////output" + Default: "s3:////output" (default: None) constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing - to a constraints JSON file. + to a constraints JSON file. (default: None) monitor_schedule_name (str): Schedule name. If not specified, the processor generates a default job name, based on the image name and current timestamp. + (default: None) schedule_cron_expression (str): The cron expression that dictates the frequency that this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid - expressions. Default: Daily. + expressions. Default: Daily. (default: None) enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of - the baselining or monitoring jobs. + the baselining or monitoring jobs. (default: True) + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run + the monitoring schedule on the batch transform (default: None) """ + # we default ground_truth_input to None in the function signature + # but verify they are giving here for positional argument + # backward compatibility reason. + if not ground_truth_input: + raise ValueError("ground_truth_input can not be None.") if self.job_definition_name is not None or self.monitoring_schedule_name is not None: message = ( "It seems that this object was already used to create an Amazon Model " @@ -540,6 +573,15 @@ def create_monitoring_schedule( _LOGGER.error(message) raise ValueError(message) + if (batch_transform_input is not None) ^ (endpoint_input is None): + message = ( + "Need to have either batch_transform_input or endpoint_input to create an " + "Amazon Model Monitoring Schedule. " + "Please provide only one of the above required inputs" + ) + _LOGGER.error(message) + raise ValueError(message) + # create job definition monitor_schedule_name = self._generate_monitoring_schedule_name( schedule_name=monitor_schedule_name @@ -569,6 +611,7 @@ def create_monitoring_schedule( env=self.env, tags=self.tags, network_config=self.network_config, + batch_transform_input=batch_transform_input, ) self.sagemaker_session.sagemaker_client.create_model_bias_job_definition(**request_dict) @@ -612,6 +655,7 @@ def update_monitoring_schedule( max_runtime_in_seconds=None, env=None, network_config=None, + batch_transform_input=None, ): """Updates the existing monitoring schedule. @@ -651,6 +695,8 @@ def update_monitoring_schedule( network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run + the monitoring schedule on the batch transform """ valid_args = { arg: value for arg, value in locals().items() if arg != "self" and value is not None @@ -660,6 +706,15 @@ def update_monitoring_schedule( if len(valid_args) <= 0: return + if batch_transform_input is not None and endpoint_input is not None: + message = ( + "Need to have either batch_transform_input or endpoint_input to create an " + "Amazon Model Monitoring Schedule. " + "Please provide only one of the above required inputs" + ) + _LOGGER.error(message) + raise ValueError(message) + # Only need to update schedule expression if len(valid_args) == 1 and schedule_cron_expression is not None: self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression) @@ -691,6 +746,7 @@ def update_monitoring_schedule( env=env, tags=self.tags, network_config=network_config, + batch_transform_input=batch_transform_input, ) self.sagemaker_session.sagemaker_client.create_model_bias_job_definition(**request_dict) try: @@ -895,19 +951,20 @@ def suggest_baseline( # noinspection PyMethodOverriding def create_monitoring_schedule( self, - endpoint_input, + endpoint_input=None, analysis_config=None, output_s3_uri=None, constraints=None, monitor_schedule_name=None, schedule_cron_expression=None, enable_cloudwatch_metrics=True, + batch_transform_input=None, ): """Creates a monitoring schedule. Args: endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor. - This can either be the endpoint name or an EndpointInput. + This can either be the endpoint name or an EndpointInput. (default: None) analysis_config (str or ExplainabilityAnalysisConfig): URI to the analysis_config for the explainability job. If it is None then configuration of the latest baselining job will be reused, but if no baselining job then fail the call. @@ -923,6 +980,8 @@ def create_monitoring_schedule( expressions. Default: Daily. enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of the baselining or monitoring jobs. + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to + run the monitoring schedule on the batch transform """ if self.job_definition_name is not None or self.monitoring_schedule_name is not None: message = ( @@ -933,6 +992,15 @@ def create_monitoring_schedule( _LOGGER.error(message) raise ValueError(message) + if (batch_transform_input is not None) ^ (endpoint_input is None): + message = ( + "Need to have either batch_transform_input or endpoint_input to create an " + "Amazon Model Monitoring Schedule." + "Please provide only one of the above required inputs" + ) + _LOGGER.error(message) + raise ValueError(message) + # create job definition monitor_schedule_name = self._generate_monitoring_schedule_name( schedule_name=monitor_schedule_name @@ -961,6 +1029,7 @@ def create_monitoring_schedule( env=self.env, tags=self.tags, network_config=self.network_config, + batch_transform_input=batch_transform_input, ) self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition( **request_dict @@ -1005,6 +1074,7 @@ def update_monitoring_schedule( max_runtime_in_seconds=None, env=None, network_config=None, + batch_transform_input=None, ): """Updates the existing monitoring schedule. @@ -1043,6 +1113,8 @@ def update_monitoring_schedule( network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to + run the monitoring schedule on the batch transform """ valid_args = { arg: value for arg, value in locals().items() if arg != "self" and value is not None @@ -1052,6 +1124,15 @@ def update_monitoring_schedule( if len(valid_args) <= 0: raise ValueError("Nothing to update.") + if batch_transform_input is not None and endpoint_input is not None: + message = ( + "Need to have either batch_transform_input or endpoint_input to create an " + "Amazon Model Monitoring Schedule. " + "Please provide only one of the above required inputs" + ) + _LOGGER.error(message) + raise ValueError(message) + # Only need to update schedule expression if len(valid_args) == 1 and schedule_cron_expression is not None: self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression) @@ -1084,6 +1165,7 @@ def update_monitoring_schedule( env=env, tags=self.tags, network_config=network_config, + batch_transform_input=batch_transform_input, ) self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition( **request_dict diff --git a/src/sagemaker/model_monitor/dataset_format.py b/src/sagemaker/model_monitor/dataset_format.py index 319cd29ab8..36f868bd83 100644 --- a/src/sagemaker/model_monitor/dataset_format.py +++ b/src/sagemaker/model_monitor/dataset_format.py @@ -59,3 +59,46 @@ def sagemaker_capture_json(): """ return {"sagemakerCaptureJson": {}} + + +class MonitoringDatasetFormat(object): + """Represents a Dataset Format that is used when calling a DefaultModelMonitor.""" + + @staticmethod + def csv(header=True): + """Returns a DatasetFormat JSON string for use with a DefaultModelMonitor. + + Args: + header (bool): Whether the csv dataset to baseline and monitor has a header. + Default: True. + output_columns_position (str): The position of the output columns. + Must be one of ("START", "END"). Default: "START". + + Returns: + dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor. + + """ + return {"Csv": {"Header": header}} + + @staticmethod + def json(lines=True): + """Returns a DatasetFormat JSON string for use with a DefaultModelMonitor. + + Args: + lines (bool): Whether the file should be read as a json object per line. Default: True. + + Returns: + dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor. + + """ + return {"Json": {"Line": lines}} + + @staticmethod + def parquet(): + """Returns a DatasetFormat SageMaker Capture Json string for use with a DefaultModelMonitor. + + Returns: + dict: JSON string containing DatasetFormat to be used by DefaultModelMonitor. + + """ + return {"Parquet": {}} diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 39bdaed7c5..7599da4ff4 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -23,6 +23,8 @@ import pathlib import logging import uuid +from typing import Union +import attr from six import string_types from six.moves.urllib.parse import urlparse @@ -31,6 +33,7 @@ from sagemaker import image_uris, s3 from sagemaker.exceptions import UnexpectedStatusException from sagemaker.model_monitor.monitoring_files import Constraints, ConstraintViolations, Statistics +from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat from sagemaker.network import NetworkConfig from sagemaker.processing import Processor, ProcessingInput, ProcessingJob, ProcessingOutput from sagemaker.session import Session @@ -217,12 +220,13 @@ def run_baseline( def create_monitoring_schedule( self, - endpoint_input, - output, + endpoint_input=None, + output=None, statistics=None, constraints=None, monitor_schedule_name=None, schedule_cron_expression=None, + batch_transform_input=None, ): """Creates a monitoring schedule to monitor an Amazon SageMaker Endpoint. @@ -233,22 +237,25 @@ def create_monitoring_schedule( Args: endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor. - This can either be the endpoint name or an EndpointInput. + This can either be the endpoint name or an EndpointInput. (default: None) output (sagemaker.model_monitor.MonitoringOutput): The output of the monitoring - schedule. + schedule. (default: None) statistics (sagemaker.model_monitor.Statistic or str): If provided alongside constraints, these will be used for monitoring the endpoint. This can be a sagemaker.model_monitor.Statistic object or an S3 uri pointing to a statistic - JSON file. + JSON file. (default: None) constraints (sagemaker.model_monitor.Constraints or str): If provided alongside statistics, these will be used for monitoring the endpoint. This can be a sagemaker.model_monitor.Constraints object or an S3 uri pointing to a constraints - JSON file. + JSON file. (default: None) monitor_schedule_name (str): Schedule name. If not specified, the processor generates - a default job name, based on the image name and current timestamp. + a default job name, based on the image name and current timestamp. (default: None) schedule_cron_expression (str): The cron expression that dictates the frequency that this job runs at. See sagemaker.model_monitor.CronExpressionGenerator for valid - expressions. Default: Daily. + expressions. Default: Daily. (default: None) + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to + run the monitoring schedule on the batch transform + (default: None) """ if self.monitoring_schedule_name is not None: @@ -260,11 +267,28 @@ def create_monitoring_schedule( print(message) raise ValueError(message) + if not output: + raise ValueError("output can not be None.") + + if (batch_transform_input is not None) ^ (endpoint_input is None): + message = ( + "Need to have either batch_transform_input or endpoint_input to create an " + "Amazon Model Monitoring Schedule. " + "Please provide only one of the above required inputs" + ) + _LOGGER.error(message) + raise ValueError(message) + self.monitoring_schedule_name = self._generate_monitoring_schedule_name( schedule_name=monitor_schedule_name ) - normalized_endpoint_input = self._normalize_endpoint_input(endpoint_input=endpoint_input) + if batch_transform_input is not None: + normalized_monitoring_input = batch_transform_input._to_request_dict() + else: + normalized_monitoring_input = self._normalize_endpoint_input( + endpoint_input=endpoint_input + )._to_request_dict() normalized_monitoring_output = self._normalize_monitoring_output_fields(output=output) @@ -301,7 +325,7 @@ def create_monitoring_schedule( schedule_expression=schedule_cron_expression, statistics_s3_uri=statistics_s3_uri, constraints_s3_uri=constraints_s3_uri, - monitoring_inputs=[normalized_endpoint_input._to_request_dict()], + monitoring_inputs=[normalized_monitoring_input], monitoring_output_config=monitoring_output_config, instance_count=self.instance_count, instance_type=self.instance_type, @@ -1498,7 +1522,7 @@ def suggest_baseline( def create_monitoring_schedule( self, - endpoint_input, + endpoint_input=None, record_preprocessor_script=None, post_analytics_processor_script=None, output_s3_uri=None, @@ -1507,6 +1531,7 @@ def create_monitoring_schedule( monitor_schedule_name=None, schedule_cron_expression=None, enable_cloudwatch_metrics=True, + batch_transform_input=None, ): """Creates a monitoring schedule to monitor an Amazon SageMaker Endpoint. @@ -1517,7 +1542,7 @@ def create_monitoring_schedule( Args: endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor. - This can either be the endpoint name or an EndpointInput. + This can either be the endpoint name or an EndpointInput. (default: None) record_preprocessor_script (str): The path to the record preprocessor script. This can be a local path or an S3 uri. post_analytics_processor_script (str): The path to the record post-analytics processor @@ -1540,6 +1565,8 @@ def create_monitoring_schedule( expressions. Default: Daily. enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of the baselining or monitoring jobs. + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to + run the monitoring schedule on the batch transform (default: None) """ if self.job_definition_name is not None or self.monitoring_schedule_name is not None: message = ( @@ -1550,6 +1577,15 @@ def create_monitoring_schedule( _LOGGER.error(message) raise ValueError(message) + if (batch_transform_input is not None) ^ (endpoint_input is None): + message = ( + "Need to have either batch_transform_input or endpoint_input to create an " + "Amazon Model Monitoring Schedule. " + "Please provide only one of the above required inputs" + ) + _LOGGER.error(message) + raise ValueError(message) + # create job definition monitor_schedule_name = self._generate_monitoring_schedule_name( schedule_name=monitor_schedule_name @@ -1579,6 +1615,7 @@ def create_monitoring_schedule( env=self.env, tags=self.tags, network_config=self.network_config, + batch_transform_input=batch_transform_input, ) self.sagemaker_session.sagemaker_client.create_data_quality_job_definition(**request_dict) @@ -1804,6 +1841,7 @@ def _update_data_quality_monitoring_schedule( max_runtime_in_seconds=None, env=None, network_config=None, + batch_transform_input=None, ): """Updates the existing monitoring schedule. @@ -1844,6 +1882,8 @@ def _update_data_quality_monitoring_schedule( network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to + run the monitoring schedule on the batch transform (default: None) """ valid_args = { arg: value for arg, value in locals().items() if arg != "self" and value is not None @@ -1885,6 +1925,7 @@ def _update_data_quality_monitoring_schedule( env=env, tags=self.tags, network_config=network_config, + batch_transform_input=batch_transform_input, ) self.sagemaker_session.sagemaker_client.create_data_quality_job_definition(**request_dict) try: @@ -2132,6 +2173,7 @@ def _build_create_data_quality_job_definition_request( env=None, tags=None, network_config=None, + batch_transform_input=None, ): """Build the request for job definition creation API @@ -2169,6 +2211,8 @@ def _build_create_data_quality_job_definition_request( network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to + run the monitoring schedule on the batch transform Returns: dict: request parameters to create job definition. @@ -2247,6 +2291,8 @@ def _build_create_data_quality_job_definition_request( endpoint_input=endpoint_input ) job_input = normalized_endpoint_input._to_request_dict() + elif batch_transform_input is not None: + job_input = batch_transform_input._to_request_dict() # job output if output_s3_uri is not None: @@ -2518,9 +2564,9 @@ def suggest_baseline( # noinspection PyMethodOverriding def create_monitoring_schedule( self, - endpoint_input, - ground_truth_input, - problem_type, + endpoint_input=None, + ground_truth_input=None, + problem_type=None, record_preprocessor_script=None, post_analytics_processor_script=None, output_s3_uri=None, @@ -2528,15 +2574,19 @@ def create_monitoring_schedule( monitor_schedule_name=None, schedule_cron_expression=None, enable_cloudwatch_metrics=True, + batch_transform_input=None, ): """Creates a monitoring schedule. Args: endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor. This can either be the endpoint name or an EndpointInput. + (default: None) ground_truth_input (str): S3 URI to ground truth dataset. + (default: None) problem_type (str): The type of problem of this model quality monitoring. Valid values are "Regression", "BinaryClassification", "MulticlassClassification". + (default: None) record_preprocessor_script (str): The path to the record preprocessor script. This can be a local path or an S3 uri. post_analytics_processor_script (str): The path to the record post-analytics processor @@ -2553,7 +2603,17 @@ def create_monitoring_schedule( expressions. Default: Daily. enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of the baselining or monitoring jobs. + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to + run the monitoring schedule on the batch transform """ + # we default below two parameters to None in the function signature + # but verify they are giving here for positional argument + # backward compatibility reason. + if not ground_truth_input: + raise ValueError("ground_truth_input can not be None.") + if not problem_type: + raise ValueError("problem_type can not be None.") + if self.job_definition_name is not None or self.monitoring_schedule_name is not None: message = ( "It seems that this object was already used to create an Amazon Model " @@ -2563,6 +2623,15 @@ def create_monitoring_schedule( _LOGGER.error(message) raise ValueError(message) + if (batch_transform_input is not None) ^ (endpoint_input is None): + message = ( + "Need to have either batch_transform_input or endpoint_input to create an " + "Amazon Model Monitoring Schedule. " + "Please provide only one of the above required inputs" + ) + _LOGGER.error(message) + raise ValueError(message) + # create job definition monitor_schedule_name = self._generate_monitoring_schedule_name( schedule_name=monitor_schedule_name @@ -2593,6 +2662,7 @@ def create_monitoring_schedule( env=self.env, tags=self.tags, network_config=self.network_config, + batch_transform_input=batch_transform_input, ) self.sagemaker_session.sagemaker_client.create_model_quality_job_definition(**request_dict) @@ -2637,6 +2707,7 @@ def update_monitoring_schedule( max_runtime_in_seconds=None, env=None, network_config=None, + batch_transform_input=None, ): """Updates the existing monitoring schedule. @@ -2679,6 +2750,8 @@ def update_monitoring_schedule( network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to + run the monitoring schedule on the batch transform """ valid_args = { arg: value for arg, value in locals().items() if arg != "self" and value is not None @@ -2721,6 +2794,7 @@ def update_monitoring_schedule( env=env, tags=self.tags, network_config=network_config, + batch_transform_input=batch_transform_input, ) self.sagemaker_session.sagemaker_client.create_model_quality_job_definition(**request_dict) try: @@ -2832,6 +2906,7 @@ def _build_create_model_quality_job_definition_request( env=None, tags=None, network_config=None, + batch_transform_input=None, ): """Build the request for job definition creation API @@ -2872,6 +2947,8 @@ def _build_create_model_quality_job_definition_request( network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. + batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to + run the monitoring schedule on the batch transform Returns: dict: request parameters to create job definition. @@ -2947,6 +3024,9 @@ def _build_create_model_quality_job_definition_request( endpoint_input=endpoint_input ) job_input = normalized_endpoint_input._to_request_dict() + elif batch_transform_input is not None: + job_input = batch_transform_input._to_request_dict() + if ground_truth_input is not None: job_input["GroundTruthS3Input"] = dict(S3Uri=ground_truth_input) @@ -3382,6 +3462,123 @@ def _to_request_dict(self): return endpoint_input_request +@attr.s +class MonitoringInput(object): + """Accepts parameters specifying batch transform or endpoint inputs for monitoring execution. + + MonitoringInput accepts parameters that specify additional parameters while monitoring jobs. + It also provides a method to turn those parameters into a dictionary. + + Args: + start_time_offset (str): Monitoring start time offset, e.g. "-PT1H" + end_time_offset (str): Monitoring end time offset, e.g. "-PT0H". + features_attribute (str): JSONpath to locate features in JSONlines dataset. + Only used for ModelBiasMonitor and ModelExplainabilityMonitor + inference_attribute (str): Index or JSONpath to locate predicted label(s). + Only used for ModelQualityMonitor, ModelBiasMonitor, and ModelExplainabilityMonitor + probability_attribute (str): Index or JSONpath to locate probabilities. + Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor + probability_threshold_attribute (float): threshold to convert probabilities to binaries + Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor + """ + + start_time_offset: str = attr.ib() + end_time_offset: str = attr.ib() + features_attribute: str = attr.ib() + inference_attribute: str = attr.ib() + probability_attribute: Union[str, int] = attr.ib() + probability_threshold_attribute: float = attr.ib() + + +class BatchTransformInput(MonitoringInput): + """Accepts parameters that specify a batch transform input for monitoring schedule. + + It also provides a method to turn those parameters into a dictionary. + """ + + def __init__( + self, + data_captured_destination_s3_uri: str, + destination: str, + dataset_format: MonitoringDatasetFormat, + s3_input_mode: str = "File", + s3_data_distribution_type: str = "FullyReplicated", + start_time_offset: str = None, + end_time_offset: str = None, + features_attribute: str = None, + inference_attribute: str = None, + probability_attribute: str = None, + probability_threshold_attribute: str = None, + ): + """Initialize a `BatchTransformInput` instance. + + Args: + data_captured_destination_s3_uri (str): Location to the batch transform captured data + file which needs to be analysed. + destination (str): The destination of the input. + s3_input_mode (str): The S3 input mode. Can be one of: "File", "Pipe. (default: File) + s3_data_distribution_type (str): The S3 Data Distribution Type. Can be one of: + "FullyReplicated", "ShardedByS3Key" (default: FullyReplicated) + start_time_offset (str): Monitoring start time offset, e.g. "-PT1H" (default: None) + end_time_offset (str): Monitoring end time offset, e.g. "-PT0H". (default: None) + features_attribute (str): JSONpath to locate features in JSONlines dataset. + Only used for ModelBiasMonitor and ModelExplainabilityMonitor (default: None) + inference_attribute (str): Index or JSONpath to locate predicted label(s). + Only used for ModelQualityMonitor, ModelBiasMonitor, and ModelExplainabilityMonitor + (default: None) + probability_attribute (str): Index or JSONpath to locate probabilities. + Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor + (default: None) + probability_threshold_attribute (float): threshold to convert probabilities to binaries + Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor + (default: None) + + """ + self.data_captured_destination_s3_uri = data_captured_destination_s3_uri + self.destination = destination + self.s3_input_mode = s3_input_mode + self.s3_data_distribution_type = s3_data_distribution_type + self.dataset_format = dataset_format + + super(BatchTransformInput, self).__init__( + start_time_offset=start_time_offset, + end_time_offset=end_time_offset, + features_attribute=features_attribute, + inference_attribute=inference_attribute, + probability_attribute=probability_attribute, + probability_threshold_attribute=probability_threshold_attribute, + ) + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + batch_transform_input_data = { + "DataCapturedDestinationS3Uri": self.data_captured_destination_s3_uri, + "LocalPath": self.destination, + "S3InputMode": self.s3_input_mode, + "S3DataDistributionType": self.s3_data_distribution_type, + "DatasetFormat": self.dataset_format, + } + + if self.start_time_offset is not None: + batch_transform_input_data["StartTimeOffset"] = self.start_time_offset + if self.end_time_offset is not None: + batch_transform_input_data["EndTimeOffset"] = self.end_time_offset + if self.features_attribute is not None: + batch_transform_input_data["FeaturesAttribute"] = self.features_attribute + if self.inference_attribute is not None: + batch_transform_input_data["InferenceAttribute"] = self.inference_attribute + if self.probability_attribute is not None: + batch_transform_input_data["ProbabilityAttribute"] = self.probability_attribute + if self.probability_threshold_attribute is not None: + batch_transform_input_data[ + "ProbabilityThresholdAttribute" + ] = self.probability_threshold_attribute + + batch_transform_input_request = {"BatchTransformInput": batch_transform_input_data} + + return batch_transform_input_request + + class MonitoringOutput(object): """Accepts parameters that specify an S3 output for a monitoring job. diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 331bc5cb9f..733263ce0b 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -34,7 +34,7 @@ from sagemaker._studio import _append_project_tags from sagemaker.deprecations import deprecated_class -from sagemaker.inputs import ShuffleConfig, TrainingInput +from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig from sagemaker.user_agent import prepend_user_agent from sagemaker.utils import ( name_from_image, @@ -2454,6 +2454,7 @@ def _get_transform_request( tags, data_processing, model_client_config=None, + batch_data_capture_config: BatchDataCaptureConfig = None, ): """Construct an dict can be used to create an Amazon SageMaker transform job. @@ -2489,6 +2490,9 @@ def _get_transform_request( model_client_config (dict): A dictionary describing the model configuration for the job. Dictionary contains two optional keys, 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. + batch_data_capture_config (BatchDataCaptureConfig): Configuration object which + specifies the configurations related to the batch data capture for the transform job + (default: None) Returns: Dict: a create transform job request dict @@ -2525,6 +2529,9 @@ def _get_transform_request( if model_client_config and len(model_client_config) > 0: transform_request["ModelClientConfig"] = model_client_config + if batch_data_capture_config is not None: + transform_request["DataCaptureConfig"] = batch_data_capture_config._to_request_dict() + return transform_request def transform( @@ -2542,6 +2549,7 @@ def transform( tags, data_processing, model_client_config=None, + batch_data_capture_config: BatchDataCaptureConfig = None, ): """Create an Amazon SageMaker transform job. @@ -2577,6 +2585,8 @@ def transform( model_client_config (dict): A dictionary describing the model configuration for the job. Dictionary contains two optional keys, 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. + batch_data_capture_config (BatchDataCaptureConfig): Configuration object which + specifies the configurations related to the batch data capture for the transform job """ tags = _append_project_tags(tags) transform_request = self._get_transform_request( @@ -2593,6 +2603,7 @@ def transform( tags=tags, data_processing=data_processing, model_client_config=model_client_config, + batch_data_capture_config=batch_data_capture_config, ) def submit(request): diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 6b1f51f9d4..a5f2ee75bf 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -14,11 +14,11 @@ from __future__ import absolute_import from typing import Union, Optional, List, Dict - from botocore import exceptions from sagemaker.job import _Job from sagemaker.session import Session +from sagemaker.inputs import BatchDataCaptureConfig from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.workflow import is_pipeline_variable @@ -127,6 +127,7 @@ def transform( join_source: Optional[Union[str, PipelineVariable]] = None, experiment_config: Optional[Dict[str, str]] = None, model_client_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + batch_data_capture_config: BatchDataCaptureConfig = None, wait: bool = True, logs: bool = True, ): @@ -193,6 +194,9 @@ def transform( configuration. Dictionary contains two optional keys, 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. (default: ``None``). + batch_data_capture_config (BatchDataCaptureConfig): Configuration object which + specifies the configurations related to the batch data capture for the transform job + (default: ``None``). wait (bool): Whether the call should wait until the job completes (default: ``True``). logs (bool): Whether to show the logs produced by the job. @@ -237,6 +241,7 @@ def transform( join_source, experiment_config, model_client_config, + batch_data_capture_config, ) if wait: @@ -372,6 +377,7 @@ def start_new( join_source, experiment_config, model_client_config, + batch_data_capture_config, ): """Placeholder docstring""" @@ -387,6 +393,7 @@ def start_new( join_source, experiment_config, model_client_config, + batch_data_capture_config, ) transformer.sagemaker_session.transform(**transform_args) @@ -407,6 +414,7 @@ def _get_transform_args( join_source, experiment_config, model_client_config, + batch_data_capture_config, ): """Placeholder docstring""" @@ -430,6 +438,7 @@ def _get_transform_args( "model_client_config": model_client_config, "tags": transformer.tags, "data_processing": data_processing, + "batch_data_capture_config": batch_data_capture_config, } ) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index d54320b3cc..2633304a93 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -710,6 +710,7 @@ def arguments(self) -> RequestType: join_source=self.inputs.join_source, model_client_config=self.inputs.model_client_config, experiment_config=dict(), + batch_data_capture_config=self.inputs.batch_data_capture_config, ) request_dict = self.transformer.sagemaker_session._get_transform_request( **transform_args diff --git a/tests/integ/test_model_monitor.py b/tests/integ/test_model_monitor.py index 4b6d3a39ae..f6d5ee88ed 100644 --- a/tests/integ/test_model_monitor.py +++ b/tests/integ/test_model_monitor.py @@ -26,11 +26,13 @@ from tests.integ import DATA_DIR from sagemaker.model_monitor import DatasetFormat +from sagemaker.model_monitor import MonitoringDatasetFormat from sagemaker.model_monitor import NetworkConfig, Statistics, Constraints from sagemaker.model_monitor import ModelMonitor from sagemaker.model_monitor import DefaultModelMonitor from sagemaker.model_monitor import MonitoringOutput from sagemaker.model_monitor import DataCaptureConfig +from sagemaker.model_monitor import BatchTransformInput from sagemaker.model_monitor.data_capture_config import _MODEL_MONITOR_S3_PATH from sagemaker.model_monitor.data_capture_config import _DATA_CAPTURE_S3_PATH from sagemaker.model_monitor import CronExpressionGenerator @@ -274,6 +276,84 @@ def updated_output_kms_key(sagemaker_session): ) +@pytest.mark.skipif( + tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS, + reason="ModelMonitoring is not yet supported in this region.", +) +@pytest.mark.release +def test_default_monitoring_batch_transform_schedule_name( + sagemaker_session, output_kms_key, volume_kms_key +): + my_default_monitor = DefaultModelMonitor( + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + volume_size_in_gb=VOLUME_SIZE_IN_GB, + volume_kms_key=volume_kms_key, + output_kms_key=output_kms_key, + max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS, + sagemaker_session=sagemaker_session, + env=ENVIRONMENT, + tags=TAGS, + network_config=NETWORK_CONFIG, + ) + + output_s3_uri = os.path.join( + "s3://", + sagemaker_session.default_bucket(), + "integ-test-monitoring-output-bucket", + str(uuid.uuid4()), + ) + + data_captured_destination_s3_uri = os.path.join( + "s3://", + sagemaker_session.default_bucket(), + "sagemaker-tensorflow-serving-batch-transform", + str(uuid.uuid4()), + ) + + batch_transform_input = BatchTransformInput( + data_captured_destination_s3_uri=data_captured_destination_s3_uri, + destination="/opt/ml/processing/output", + dataset_format=MonitoringDatasetFormat.csv(header=False), + ) + + statistics = Statistics.from_file_path( + statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"), + sagemaker_session=sagemaker_session, + ) + + constraints = Constraints.from_file_path( + constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"), + sagemaker_session=sagemaker_session, + ) + + my_default_monitor.create_monitoring_schedule( + batch_transform_input=batch_transform_input, + output_s3_uri=output_s3_uri, + statistics=statistics, + constraints=constraints, + schedule_cron_expression=HOURLY_CRON_EXPRESSION, + enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS, + ) + + _wait_for_schedule_changes_to_apply(monitor=my_default_monitor) + + schedule_description = my_default_monitor.describe_schedule() + _verify_default_monitoring_schedule_with_batch_transform( + sagemaker_session=sagemaker_session, + schedule_description=schedule_description, + cron_expression=HOURLY_CRON_EXPRESSION, + statistics=statistics, + constraints=constraints, + output_kms_key=output_kms_key, + volume_kms_key=volume_kms_key, + network_config=NETWORK_CONFIG, + ) + + my_default_monitor.stop_monitoring_schedule() + + @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS, reason="ModelMonitoring is not yet supported in this region.", @@ -1572,3 +1652,99 @@ def _verify_default_monitoring_schedule( ) else: assert network_config is None + + +def _verify_default_monitoring_schedule_with_batch_transform( + sagemaker_session, + schedule_description, + cron_expression=CronExpressionGenerator.daily(), + statistics=None, + constraints=None, + output_kms_key=None, + volume_kms_key=None, + instant_count=INSTANCE_COUNT, + instant_type=INSTANCE_TYPE, + volume_size_in_gb=VOLUME_SIZE_IN_GB, + network_config=None, + max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS, + publish_cloudwatch_metrics="Enabled", + env_key=ENV_KEY_1, + env_value=ENV_VALUE_1, + preprocessor=None, + postprocessor=None, + role=ROLE, +): + assert ( + schedule_description["MonitoringScheduleConfig"]["ScheduleConfig"]["ScheduleExpression"] + == cron_expression + ) + assert schedule_description["MonitoringScheduleConfig"]["MonitoringType"] == "DataQuality" + + job_definition_name = schedule_description["MonitoringScheduleConfig"].get( + "MonitoringJobDefinitionName" + ) + if job_definition_name: + job_desc = sagemaker_session.sagemaker_client.describe_data_quality_job_definition( + JobDefinitionName=job_definition_name, + ) + # app specification + app_specification = job_desc["DataQualityAppSpecification"] + env = app_specification["Environment"] + baseline_config = job_desc.get("DataQualityBaselineConfig") + job_input = job_desc["DataQualityJobInput"] + job_output_config = job_desc["DataQualityJobOutputConfig"] + client_config = job_desc["JobResources"]["ClusterConfig"] + else: + job_desc = schedule_description["MonitoringScheduleConfig"]["MonitoringJobDefinition"] + app_specification = job_desc["MonitoringAppSpecification"] + env = job_desc["Environment"] + baseline_config = job_desc.get("BaselineConfig") + job_input = job_desc["MonitoringInputs"][0] + job_output_config = job_desc["MonitoringOutputConfig"] + client_config = job_desc["MonitoringResources"]["ClusterConfig"] + + assert DEFAULT_IMAGE_SUFFIX in app_specification["ImageUri"] + if env.get(env_key): + assert env[env_key] == env_value + assert env["publish_cloudwatch_metrics"] == publish_cloudwatch_metrics + assert app_specification.get("RecordPreprocessorSourceUri") == preprocessor + assert app_specification.get("PostAnalyticsProcessorSourceUri") == postprocessor + + # baseline + if baseline_config: + if baseline_config["StatisticsResource"]: + assert baseline_config["StatisticsResource"]["S3Uri"] == statistics.file_s3_uri + else: + assert statistics is None + if baseline_config["ConstraintsResource"]: + assert baseline_config["ConstraintsResource"]["S3Uri"] == constraints.file_s3_uri + else: + assert constraints is None + else: + assert statistics is None + assert constraints is None + # job input + assert ( + "sagemaker-tensorflow-serving" + in job_input["BatchTransformInput"]["DataCapturedDestinationS3Uri"] + ) + # job output config + assert len(job_output_config["MonitoringOutputs"]) == 1 + assert job_output_config.get("KmsKeyId") == output_kms_key + # job resources + assert client_config["InstanceCount"] == instant_count + assert client_config["InstanceType"] == instant_type + assert client_config["VolumeSizeInGB"] == volume_size_in_gb + assert client_config.get("VolumeKmsKeyId") == volume_kms_key + # role + assert role in job_desc["RoleArn"] + # stop condition + assert job_desc["StoppingCondition"]["MaxRuntimeInSeconds"] == max_runtime_in_seconds + # network config + if job_desc.get("NetworkConfig"): + assert ( + job_desc["NetworkConfig"].get("EnableNetworkIsolation") + == network_config.enable_network_isolation + ) + else: + assert network_config is None diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 7829b69c28..a0e37ffc77 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -18,12 +18,13 @@ import pytest -from sagemaker import KMeans, s3 +from sagemaker import KMeans, s3, get_execution_role from sagemaker.mxnet import MXNet from sagemaker.pytorch import PyTorchModel from sagemaker.tensorflow import TensorFlow from sagemaker.transformer import Transformer from sagemaker.estimator import Estimator +from sagemaker.inputs import BatchDataCaptureConfig from sagemaker.utils import unique_name_from_base from tests.integ import ( datasets, @@ -35,9 +36,46 @@ from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer from tests.integ.vpc_test_utils import get_or_create_vpc_resources +from sagemaker.model_monitor import DatasetFormat, Statistics + +from sagemaker.workflow.check_job_config import CheckJobConfig +from sagemaker.workflow.quality_check_step import ( + DataQualityCheckConfig, + ModelQualityCheckConfig, +) +from sagemaker.workflow.parameters import ParameterString +from sagemaker.s3 import S3Uploader +from sagemaker.clarify import ( + BiasConfig, + DataConfig, +) +from sagemaker.workflow.clarify_check_step import ( + DataBiasCheckConfig, +) + +_INSTANCE_COUNT = 1 +_INSTANCE_TYPE = "ml.c5.xlarge" +_HEADERS = ["Label", "F1", "F2", "F3", "F4"] +_CHECK_FAIL_ERROR_MSG_CLARIFY = "ClientError: Clarify check failed. See violation report" +_PROBLEM_TYPE = "Regression" +_HEADER_OF_LABEL = "Label" +_HEADER_OF_PREDICTED_LABEL = "Prediction" +_CHECK_FAIL_ERROR_MSG_QUALITY = "ClientError: Quality check failed. See violation report" + + MXNET_MNIST_PATH = os.path.join(DATA_DIR, "mxnet_mnist") +@pytest.fixture(scope="module") +def role(sagemaker_session): + return get_execution_role(sagemaker_session) + + +@pytest.fixture +def pipeline_name(): + return unique_name_from_base("my-pipeline-transform") + + @pytest.fixture(scope="module") def mxnet_estimator( sagemaker_session, @@ -78,6 +116,131 @@ def mxnet_transform_input(sagemaker_session): ) +@pytest.fixture +def check_job_config(role, pipeline_session): + return CheckJobConfig( + role=role, + instance_count=_INSTANCE_COUNT, + instance_type=_INSTANCE_TYPE, + volume_size_in_gb=60, + sagemaker_session=pipeline_session, + ) + + +@pytest.fixture +def supplied_baseline_statistics_uri_param(): + return ParameterString(name="SuppliedBaselineStatisticsUri", default_value="") + + +@pytest.fixture +def supplied_baseline_constraints_uri_param(): + return ParameterString(name="SuppliedBaselineConstraintsUri", default_value="") + + +@pytest.fixture +def dataset(pipeline_session): + dataset_local_path = os.path.join(DATA_DIR, "pipeline/clarify_check_step/dataset.csv") + dataset_s3_uri = "s3://{}/{}/{}/{}/{}".format( + pipeline_session.default_bucket(), + "clarify_check_step", + "input", + "dataset", + unique_name_from_base("dataset"), + ) + return S3Uploader.upload(dataset_local_path, dataset_s3_uri, sagemaker_session=pipeline_session) + + +@pytest.fixture +def data_config(pipeline_session, dataset): + output_path = "s3://{}/{}/{}/{}".format( + pipeline_session.default_bucket(), + "clarify_check_step", + "analysis_result", + unique_name_from_base("result"), + ) + analysis_cfg_output_path = "s3://{}/{}/{}/{}".format( + pipeline_session.default_bucket(), + "clarify_check_step", + "analysis_cfg", + unique_name_from_base("analysis_cfg"), + ) + return DataConfig( + s3_data_input_path=dataset, + s3_output_path=output_path, + s3_analysis_config_output_path=analysis_cfg_output_path, + label="Label", + headers=_HEADERS, + dataset_type="text/csv", + ) + + +@pytest.fixture +def bias_config(): + return BiasConfig( + label_values_or_threshold=[1], + facet_name="F1", + facet_values_or_threshold=[0.5], + group_name="F2", + ) + + +@pytest.fixture +def data_bias_check_config(data_config, bias_config): + return DataBiasCheckConfig( + data_config=data_config, + data_bias_config=bias_config, + ) + + +@pytest.fixture +def data_quality_baseline_dataset(): + return os.path.join(DATA_DIR, "pipeline/quality_check_step/data_quality/baseline_dataset.csv") + + +@pytest.fixture +def data_quality_check_config(data_quality_baseline_dataset): + return DataQualityCheckConfig( + baseline_dataset=data_quality_baseline_dataset, + dataset_format=DatasetFormat.csv(header=False), + ) + + +@pytest.fixture +def data_quality_supplied_baseline_statistics(sagemaker_session): + return Statistics.from_file_path( + statistics_file_path=os.path.join( + DATA_DIR, "pipeline/quality_check_step/data_quality/statistics.json" + ), + sagemaker_session=sagemaker_session, + ).file_s3_uri + + +@pytest.fixture +def model_quality_baseline_dataset(): + return os.path.join(DATA_DIR, "pipeline/quality_check_step/model_quality/baseline_dataset.csv") + + +@pytest.fixture +def model_quality_check_config(model_quality_baseline_dataset): + return ModelQualityCheckConfig( + baseline_dataset=model_quality_baseline_dataset, + dataset_format=DatasetFormat.csv(), + problem_type=_PROBLEM_TYPE, + inference_attribute=_HEADER_OF_LABEL, + ground_truth_attribute=_HEADER_OF_PREDICTED_LABEL, + ) + + +@pytest.fixture +def model_quality_supplied_baseline_statistics(sagemaker_session): + return Statistics.from_file_path( + statistics_file_path=os.path.join( + DATA_DIR, "pipeline/quality_check_step/model_quality/statistics.json" + ), + sagemaker_session=sagemaker_session, + ).file_s3_uri + + @pytest.mark.release def test_transform_mxnet( mxnet_estimator, mxnet_transform_input, sagemaker_session, cpu_instance_type @@ -248,6 +411,37 @@ def test_transform_model_client_config( assert model_client_config == transform_job_desc["ModelClientConfig"] +def test_transform_data_capture_config( + mxnet_estimator, mxnet_transform_input, sagemaker_session, cpu_instance_type +): + destination_s3_uri = os.path.join("s3://", sagemaker_session.default_bucket(), "data_capture") + batch_data_capture_config = BatchDataCaptureConfig( + destination_s3_uri=destination_s3_uri, kms_key_id="", generate_inference_id=False + ) + transformer = mxnet_estimator.transformer(1, cpu_instance_type) + + # we extract the S3Prefix from the input + filename = mxnet_transform_input.split("/")[-1] + input_prefix = mxnet_transform_input.replace(f"/{filename}", "") + transformer.transform( + input_prefix, + content_type="text/csv", + batch_data_capture_config=batch_data_capture_config, + ) + + with timeout_and_delete_model_with_transformer( + transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES + ): + transformer.wait() + transform_job_desc = sagemaker_session.sagemaker_client.describe_transform_job( + TransformJobName=transformer.latest_transform_job.name + ) + + assert ( + batch_data_capture_config._to_request_dict() == transform_job_desc["DataCaptureConfig"] + ) + + def test_transform_byo_estimator(sagemaker_session, cpu_instance_type): tags = [{"Key": "some-tag", "Value": "value-for-tag"}] diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index 9b92cac96f..1ca310a30a 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -36,6 +36,7 @@ Constraints, CronExpressionGenerator, EndpointInput, + BatchTransformInput, ExplainabilityAnalysisConfig, ModelBiasMonitor, ModelExplainabilityMonitor, @@ -48,6 +49,7 @@ ClarifyBaseliningJob, ClarifyMonitoringExecution, ) +from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat # shared CLARIFY_IMAGE_URI = "306415355426.dkr.ecr.us-west-2.amazonaws.com/sagemaker-clarify-processing:1.0" @@ -62,6 +64,7 @@ S3_INPUT_MODE = "File" S3_DATA_DISTRIBUTION_TYPE = "FullyReplicated" S3_UPLOAD_MODE = "Continuous" +DATASET_FORMAT = MonitoringDatasetFormat.csv(header=False) # For create API ROLE = "SageMakerRole" @@ -91,6 +94,8 @@ ANALYSIS_CONFIG_S3_URI = "s3://bucket/analysis_config.json" START_TIME_OFFSET = "-PT1H" END_TIME_OFFSET = "-PT0H" +DATA_CAPTURED_S3_URI = "s3://my-bucket/batch-fraud-detection/on-schedule-monitoring/in/" +SCHEDULE_DESTINATION = "/opt/ml/processing/data" OUTPUT_S3_URI = "s3://bucket/output" CONSTRAINTS = Constraints("", "s3://bucket/analysis.json") FEATURES_ATTRIBUTE = "features" @@ -138,6 +143,22 @@ }, "GroundTruthS3Input": {"S3Uri": GROUND_TRUTH_S3_URI}, } +BIAS_BATCH_TRANSFORM_JOB_INPUT = { + "BatchTransformInput": { + "DataCapturedDestinationS3Uri": DATA_CAPTURED_S3_URI, + "LocalPath": SCHEDULE_DESTINATION, + "S3InputMode": S3_INPUT_MODE, + "S3DataDistributionType": S3_DATA_DISTRIBUTION_TYPE, + "StartTimeOffset": START_TIME_OFFSET, + "EndTimeOffset": END_TIME_OFFSET, + "FeaturesAttribute": FEATURES_ATTRIBUTE, + "InferenceAttribute": str(INFERENCE_ATTRIBUTE), + "ProbabilityAttribute": str(PROBABILITY_ATTRIBUTE), + "ProbabilityThresholdAttribute": PROBABILITY_THRESHOLD_ATTRIBUTE, + "DatasetFormat": DATASET_FORMAT, + }, + "GroundTruthS3Input": {"S3Uri": GROUND_TRUTH_S3_URI}, +} STOP_CONDITION = {"MaxRuntimeInSeconds": MAX_RUNTIME_IN_SECONDS} BIAS_JOB_DEFINITION = { "ModelBiasAppSpecification": APP_SPECIFICATION, @@ -149,6 +170,17 @@ "NetworkConfig": NETWORK_CONFIG._to_request_dict(), "StoppingCondition": STOP_CONDITION, } +BIAS_BATCH_TRANSFORM_JOB_DEFINITION = { + "ModelBiasAppSpecification": APP_SPECIFICATION, + "ModelBiasJobInput": BIAS_BATCH_TRANSFORM_JOB_INPUT, + "ModelBiasJobOutputConfig": JOB_OUTPUT_CONFIG, + "JobResources": JOB_RESOURCES, + "RoleArn": ROLE_ARN, + "ModelBiasBaselineConfig": BASELINE_CONFIG, + "NetworkConfig": NETWORK_CONFIG._to_request_dict(), + "StoppingCondition": STOP_CONDITION, +} + EXPLAINABILITY_JOB_INPUT = { "EndpointInput": { "EndpointName": ENDPOINT_NAME, @@ -159,6 +191,17 @@ "InferenceAttribute": str(INFERENCE_ATTRIBUTE), } } +EXPLAINABILITY_BATCH_TRANSFORM_JOB_INPUT = { + "BatchTransformInput": { + "DataCapturedDestinationS3Uri": DATA_CAPTURED_S3_URI, + "LocalPath": SCHEDULE_DESTINATION, + "S3InputMode": S3_INPUT_MODE, + "S3DataDistributionType": S3_DATA_DISTRIBUTION_TYPE, + "FeaturesAttribute": FEATURES_ATTRIBUTE, + "InferenceAttribute": str(INFERENCE_ATTRIBUTE), + "DatasetFormat": DATASET_FORMAT, + } +} EXPLAINABILITY_JOB_DEFINITION = { "ModelExplainabilityAppSpecification": APP_SPECIFICATION, "ModelExplainabilityJobInput": EXPLAINABILITY_JOB_INPUT, @@ -168,6 +211,15 @@ "RoleArn": ROLE_ARN, "NetworkConfig": NETWORK_CONFIG._to_request_dict(), } +EXPLAINABILITY__BATCH_TRANSFORM_JOB_DEFINITION = { + "ModelExplainabilityAppSpecification": APP_SPECIFICATION, + "ModelExplainabilityJobInput": EXPLAINABILITY_BATCH_TRANSFORM_JOB_INPUT, + "ModelExplainabilityJobOutputConfig": JOB_OUTPUT_CONFIG, + "JobResources": JOB_RESOURCES, + "StoppingCondition": STOP_CONDITION, + "RoleArn": ROLE_ARN, + "NetworkConfig": NETWORK_CONFIG._to_request_dict(), +} # For update API NEW_ROLE_ARN = "arn:aws:iam::012345678902:role/{}".format(ROLE) @@ -716,6 +768,28 @@ def test_model_bias_monitor(model_bias_monitor, sagemaker_session): ) +def test_model_batch_transform_bias_monitor(model_bias_monitor, sagemaker_session): + # create schedule + _test_model_bias_monitor_batch_transform_create_schedule( + model_bias_monitor=model_bias_monitor, + sagemaker_session=sagemaker_session, + analysis_config=ANALYSIS_CONFIG_S3_URI, + constraints=CONSTRAINTS, + ) + + # update schedule + _test_model_bias_monitor_update_schedule( + model_bias_monitor=model_bias_monitor, + sagemaker_session=sagemaker_session, + ) + + # delete schedule + _test_model_bias_monitor_delete_schedule( + model_bias_monitor=model_bias_monitor, + sagemaker_session=sagemaker_session, + ) + + def test_model_bias_monitor_created_with_config(model_bias_monitor, sagemaker_session, bias_config): # create schedule analysis_config = BiasAnalysisConfig( @@ -903,6 +977,71 @@ def _test_model_bias_monitor_create_schedule( ) +def _test_model_bias_monitor_batch_transform_create_schedule( + model_bias_monitor, + sagemaker_session, + analysis_config=None, + constraints=None, + baseline_job_name=None, + batch_transform_input=BatchTransformInput( + data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI, + destination=SCHEDULE_DESTINATION, + start_time_offset=START_TIME_OFFSET, + end_time_offset=END_TIME_OFFSET, + features_attribute=FEATURES_ATTRIBUTE, + inference_attribute=str(INFERENCE_ATTRIBUTE), + probability_attribute=str(PROBABILITY_ATTRIBUTE), + probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE, + dataset_format=MonitoringDatasetFormat.csv(header=False), + ), +): + # create schedule + with patch( + "sagemaker.s3.S3Uploader.upload_string_as_file_body", return_value=ANALYSIS_CONFIG_S3_URI + ) as upload: + model_bias_monitor.create_monitoring_schedule( + batch_transform_input=batch_transform_input, + ground_truth_input=GROUND_TRUTH_S3_URI, + analysis_config=analysis_config, + output_s3_uri=OUTPUT_S3_URI, + constraints=constraints, + monitor_schedule_name=SCHEDULE_NAME, + schedule_cron_expression=CRON_HOURLY, + ) + if not isinstance(analysis_config, str): + upload.assert_called_once() + assert json.loads(upload.call_args[0][0]) == BIAS_ANALYSIS_CONFIG + + # validation + expected_arguments = { + "JobDefinitionName": model_bias_monitor.job_definition_name, + **copy.deepcopy(BIAS_BATCH_TRANSFORM_JOB_DEFINITION), + "Tags": TAGS, + } + if constraints: + expected_arguments["ModelBiasBaselineConfig"] = { + "ConstraintsResource": {"S3Uri": constraints.file_s3_uri} + } + elif baseline_job_name: + expected_arguments["ModelBiasBaselineConfig"] = { + "BaseliningJobName": baseline_job_name, + } + + sagemaker_session.sagemaker_client.create_model_bias_job_definition.assert_called_with( + **expected_arguments + ) + + sagemaker_session.sagemaker_client.create_monitoring_schedule.assert_called_with( + MonitoringScheduleName=SCHEDULE_NAME, + MonitoringScheduleConfig={ + "MonitoringJobDefinitionName": model_bias_monitor.job_definition_name, + "MonitoringType": "ModelBias", + "ScheduleConfig": {"ScheduleExpression": CRON_HOURLY}, + }, + Tags=TAGS, + ) + + def _test_model_bias_monitor_update_schedule(model_bias_monitor, sagemaker_session): # update schedule sagemaker_session.describe_monitoring_schedule = MagicMock() @@ -1146,6 +1285,30 @@ def test_model_explainability_monitor(model_explainability_monitor, sagemaker_se ) +def test_model_explainability_batch_transform_monitor( + model_explainability_monitor, sagemaker_session +): + # create schedule + _test_model_explainability_batch_transform_monitor_create_schedule( + model_explainability_monitor=model_explainability_monitor, + sagemaker_session=sagemaker_session, + analysis_config=ANALYSIS_CONFIG_S3_URI, + constraints=CONSTRAINTS, + ) + + # update schedule + _test_model_explainability_monitor_update_schedule( + model_explainability_monitor=model_explainability_monitor, + sagemaker_session=sagemaker_session, + ) + + # delete schedule + _test_model_explainability_monitor_delete_schedule( + model_explainability_monitor=model_explainability_monitor, + sagemaker_session=sagemaker_session, + ) + + def test_model_explainability_monitor_created_with_config( model_explainability_monitor, sagemaker_session, shap_config, model_config ): @@ -1339,6 +1502,67 @@ def _test_model_explainability_monitor_create_schedule( ) +def _test_model_explainability_batch_transform_monitor_create_schedule( + model_explainability_monitor, + sagemaker_session, + analysis_config=None, + constraints=None, + baseline_job_name=None, + batch_transform_input=BatchTransformInput( + data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI, + destination=SCHEDULE_DESTINATION, + features_attribute=FEATURES_ATTRIBUTE, + inference_attribute=str(INFERENCE_ATTRIBUTE), + dataset_format=MonitoringDatasetFormat.csv(header=False), + ), + explainability_analysis_config=None, +): + # create schedule + with patch( + "sagemaker.s3.S3Uploader.upload_string_as_file_body", return_value=ANALYSIS_CONFIG_S3_URI + ) as upload: + model_explainability_monitor.create_monitoring_schedule( + batch_transform_input=batch_transform_input, + analysis_config=analysis_config, + output_s3_uri=OUTPUT_S3_URI, + constraints=constraints, + monitor_schedule_name=SCHEDULE_NAME, + schedule_cron_expression=CRON_HOURLY, + ) + if not isinstance(analysis_config, str): + upload.assert_called_once() + assert json.loads(upload.call_args[0][0]) == explainability_analysis_config + + # validation + expected_arguments = { + "JobDefinitionName": model_explainability_monitor.job_definition_name, + **copy.deepcopy(EXPLAINABILITY__BATCH_TRANSFORM_JOB_DEFINITION), + "Tags": TAGS, + } + if constraints: + expected_arguments["ModelExplainabilityBaselineConfig"] = { + "ConstraintsResource": {"S3Uri": constraints.file_s3_uri} + } + elif baseline_job_name: + expected_arguments["ModelExplainabilityBaselineConfig"] = { + "BaseliningJobName": baseline_job_name, + } + + sagemaker_session.sagemaker_client.create_model_explainability_job_definition.assert_called_with( + **expected_arguments + ) + + sagemaker_session.sagemaker_client.create_monitoring_schedule.assert_called_with( + MonitoringScheduleName=SCHEDULE_NAME, + MonitoringScheduleConfig={ + "MonitoringJobDefinitionName": model_explainability_monitor.job_definition_name, + "MonitoringType": "ModelExplainability", + "ScheduleConfig": {"ScheduleExpression": CRON_HOURLY}, + }, + Tags=TAGS, + ) + + def _test_model_explainability_monitor_update_schedule( model_explainability_monitor, sagemaker_session ): diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index e375de9ce6..d622d10100 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -23,12 +23,13 @@ CronExpressionGenerator, DefaultModelMonitor, EndpointInput, + BatchTransformInput, ModelQualityMonitor, Statistics, ) -from sagemaker.model_monitor.dataset_format import DatasetFormat from sagemaker.network import NetworkConfig +from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat, DatasetFormat REGION = "us-west-2" BUCKET_NAME = "mybucket" @@ -101,6 +102,7 @@ SCHEDULE_ARN = "arn:aws:sagemaker:us-west-2:012345678901:monitoring-schedule/" + SCHEDULE_NAME OUTPUT_LOCAL_PATH = "/opt/ml/processing/output" ENDPOINT_INPUT_LOCAL_PATH = "/opt/ml/processing/input/endpoint" +SCHEDULE_DESTINATION = "/opt/ml/processing/data" SCHEDULE_NAME = "schedule" CRON_HOURLY = CronExpressionGenerator.hourly() S3_INPUT_MODE = "File" @@ -119,6 +121,8 @@ PROBABILITY_THRESHOLD_ATTRIBUTE = 0.6 PREPROCESSOR_URI = "s3://my_bucket/preprocessor.py" POSTPROCESSOR_URI = "s3://my_bucket/postprocessor.py" +DATA_CAPTURED_S3_URI = "s3://my-bucket/batch-fraud-detection/on-schedule-monitoring/in/" +DATASET_FORMAT = MonitoringDatasetFormat.csv(header=False) JOB_OUTPUT_CONFIG = { "MonitoringOutputs": [ { @@ -148,6 +152,15 @@ "S3DataDistributionType": S3_DATA_DISTRIBUTION_TYPE, }, } +DATA_QUALITY_BATCH_TRANSFORM_INPUT = { + "BatchTransformInput": { + "DataCapturedDestinationS3Uri": DATA_CAPTURED_S3_URI, + "LocalPath": SCHEDULE_DESTINATION, + "S3DataDistributionType": S3_DATA_DISTRIBUTION_TYPE, + "S3InputMode": S3_INPUT_MODE, + "DatasetFormat": DATASET_FORMAT, + } +} DATA_QUALITY_APP_SPECIFICATION = { "ImageUri": DEFAULT_IMAGE_URI, "Environment": ENVIRONMENT, @@ -168,6 +181,16 @@ "NetworkConfig": NETWORK_CONFIG._to_request_dict(), "StoppingCondition": STOP_CONDITION, } +DATA_QUALITY_BATCH_TRANSFORM_JOB_DEFINITION = { + "DataQualityAppSpecification": DATA_QUALITY_APP_SPECIFICATION, + "DataQualityBaselineConfig": DATA_QUALITY_BASELINE_CONFIG, + "DataQualityJobInput": DATA_QUALITY_BATCH_TRANSFORM_INPUT, + "DataQualityJobOutputConfig": JOB_OUTPUT_CONFIG, + "JobResources": JOB_RESOURCES, + "RoleArn": ROLE, + "NetworkConfig": NETWORK_CONFIG._to_request_dict(), + "StoppingCondition": STOP_CONDITION, +} MODEL_QUALITY_APP_SPECIFICATION = { "ImageUri": DEFAULT_IMAGE_URI, "ProblemType": PROBLEM_TYPE, @@ -191,6 +214,24 @@ }, "GroundTruthS3Input": {"S3Uri": GROUND_TRUTH_S3_URI}, } + +MODEL_QUALITY_BATCH_TRANSFORM_INPUT_JOB_INPUT = { + "BatchTransformInput": { + "DataCapturedDestinationS3Uri": DATA_CAPTURED_S3_URI, + "LocalPath": SCHEDULE_DESTINATION, + "S3InputMode": S3_INPUT_MODE, + "S3DataDistributionType": S3_DATA_DISTRIBUTION_TYPE, + "StartTimeOffset": START_TIME_OFFSET, + "EndTimeOffset": END_TIME_OFFSET, + "FeaturesAttribute": FEATURES_ATTRIBUTE, + "InferenceAttribute": INFERENCE_ATTRIBUTE, + "ProbabilityAttribute": PROBABILITY_ATTRIBUTE, + "ProbabilityThresholdAttribute": PROBABILITY_THRESHOLD_ATTRIBUTE, + "DatasetFormat": DATASET_FORMAT, + }, + "GroundTruthS3Input": {"S3Uri": GROUND_TRUTH_S3_URI}, +} + MODEL_QUALITY_JOB_DEFINITION = { "ModelQualityAppSpecification": MODEL_QUALITY_APP_SPECIFICATION, "ModelQualityJobInput": MODEL_QUALITY_JOB_INPUT, @@ -202,6 +243,17 @@ "StoppingCondition": STOP_CONDITION, } +MODEL_QUALITY_BATCH_TRANSFORM_INPUT_JOB_DEFINITION = { + "ModelQualityAppSpecification": MODEL_QUALITY_APP_SPECIFICATION, + "ModelQualityJobInput": MODEL_QUALITY_BATCH_TRANSFORM_INPUT_JOB_INPUT, + "ModelQualityJobOutputConfig": JOB_OUTPUT_CONFIG, + "JobResources": JOB_RESOURCES, + "RoleArn": ROLE, + "ModelQualityBaselineConfig": MODEL_QUALITY_BASELINE_CONFIG, + "NetworkConfig": NETWORK_CONFIG._to_request_dict(), + "StoppingCondition": STOP_CONDITION, +} + # For update API NEW_ROLE_ARN = "arn:aws:iam::012345678902:role/{}".format(ROLE) NEW_INSTANCE_COUNT = 2 @@ -480,6 +532,28 @@ def test_data_quality_monitor(data_quality_monitor, sagemaker_session): ) +def test_data_quality_batch_transform_monitor(data_quality_monitor, sagemaker_session): + # create schedule + _test_data_quality_batch_transform_monitor_create_schedule( + data_quality_monitor=data_quality_monitor, + sagemaker_session=sagemaker_session, + constraints=CONSTRAINTS, + statistics=STATISTICS, + ) + + # update schedule + _test_data_quality_monitor_update_schedule( + data_quality_monitor=data_quality_monitor, + sagemaker_session=sagemaker_session, + ) + + # delete schedule + _test_data_quality_monitor_delete_schedule( + data_quality_monitor=data_quality_monitor, + sagemaker_session=sagemaker_session, + ) + + def test_data_quality_monitor_created_by_attach(sagemaker_session): # attach and validate sagemaker_session.sagemaker_client.describe_data_quality_job_definition = MagicMock() @@ -600,6 +674,7 @@ def _test_data_quality_monitor_create_schedule( endpoint_name=ENDPOINT_NAME, destination=ENDPOINT_INPUT_LOCAL_PATH ), ): + # for endpoint input data_quality_monitor.create_monitoring_schedule( endpoint_input=endpoint_input, record_preprocessor_script=PREPROCESSOR_URI, @@ -625,6 +700,45 @@ def _test_data_quality_monitor_create_schedule( **expected_arguments ) + +def _test_data_quality_batch_transform_monitor_create_schedule( + data_quality_monitor, + sagemaker_session, + constraints=None, + statistics=None, + baseline_job_name=None, + batch_transform_input=BatchTransformInput( + data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI, + destination=SCHEDULE_DESTINATION, + dataset_format=MonitoringDatasetFormat.csv(header=False), + ), +): + # for batch transform input + data_quality_monitor.create_monitoring_schedule( + batch_transform_input=batch_transform_input, + record_preprocessor_script=PREPROCESSOR_URI, + post_analytics_processor_script=POSTPROCESSOR_URI, + output_s3_uri=OUTPUT_S3_URI, + constraints=constraints, + statistics=statistics, + monitor_schedule_name=SCHEDULE_NAME, + schedule_cron_expression=CRON_HOURLY, + ) + + # validation + expected_arguments = { + "JobDefinitionName": data_quality_monitor.job_definition_name, + **copy.deepcopy(DATA_QUALITY_BATCH_TRANSFORM_JOB_DEFINITION), + "Tags": TAGS, + } + if baseline_job_name: + baseline_config = expected_arguments.get("DataQualityBaselineConfig", {}) + baseline_config["BaseliningJobName"] = baseline_job_name + + sagemaker_session.sagemaker_client.create_data_quality_job_definition.assert_called_with( + **expected_arguments + ) + sagemaker_session.sagemaker_client.create_monitoring_schedule.assert_called_with( MonitoringScheduleName=SCHEDULE_NAME, MonitoringScheduleConfig={ @@ -881,6 +995,27 @@ def test_model_quality_monitor(model_quality_monitor, sagemaker_session): ) +def test_model_quality_batch_transform_monitor(model_quality_monitor, sagemaker_session): + # create schedule + _test_model_quality_monitor_batch_transform_create_schedule( + model_quality_monitor=model_quality_monitor, + sagemaker_session=sagemaker_session, + constraints=CONSTRAINTS, + ) + + # update schedule + _test_model_quality_monitor_update_schedule( + model_quality_monitor=model_quality_monitor, + sagemaker_session=sagemaker_session, + ) + + # delete schedule + _test_model_quality_monitor_delete_schedule( + model_quality_monitor=model_quality_monitor, + sagemaker_session=sagemaker_session, + ) + + def test_model_quality_monitor_created_by_attach(sagemaker_session): # attach and validate sagemaker_session.sagemaker_client.describe_model_quality_job_definition = MagicMock() @@ -1048,6 +1183,65 @@ def _test_model_quality_monitor_create_schedule( ) +def _test_model_quality_monitor_batch_transform_create_schedule( + model_quality_monitor, + sagemaker_session, + constraints=None, + baseline_job_name=None, + batch_transform_input=BatchTransformInput( + data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI, + destination=SCHEDULE_DESTINATION, + start_time_offset=START_TIME_OFFSET, + end_time_offset=END_TIME_OFFSET, + features_attribute=FEATURES_ATTRIBUTE, + inference_attribute=INFERENCE_ATTRIBUTE, + probability_attribute=PROBABILITY_ATTRIBUTE, + probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE, + dataset_format=MonitoringDatasetFormat.csv(header=False), + ), +): + model_quality_monitor.create_monitoring_schedule( + batch_transform_input=batch_transform_input, + ground_truth_input=GROUND_TRUTH_S3_URI, + problem_type=PROBLEM_TYPE, + record_preprocessor_script=PREPROCESSOR_URI, + post_analytics_processor_script=POSTPROCESSOR_URI, + output_s3_uri=OUTPUT_S3_URI, + constraints=constraints, + monitor_schedule_name=SCHEDULE_NAME, + schedule_cron_expression=CRON_HOURLY, + ) + + # validation + expected_arguments = { + "JobDefinitionName": model_quality_monitor.job_definition_name, + **copy.deepcopy(MODEL_QUALITY_BATCH_TRANSFORM_INPUT_JOB_DEFINITION), + "Tags": TAGS, + } + if constraints: + expected_arguments["ModelQualityBaselineConfig"] = { + "ConstraintsResource": {"S3Uri": constraints.file_s3_uri} + } + if baseline_job_name: + expected_arguments["ModelQualityBaselineConfig"] = { + "BaseliningJobName": baseline_job_name, + } + + sagemaker_session.sagemaker_client.create_model_quality_job_definition.assert_called_with( + **expected_arguments + ) + + sagemaker_session.sagemaker_client.create_monitoring_schedule.assert_called_with( + MonitoringScheduleName=SCHEDULE_NAME, + MonitoringScheduleConfig={ + "MonitoringJobDefinitionName": model_quality_monitor.job_definition_name, + "MonitoringType": "ModelQuality", + "ScheduleConfig": {"ScheduleExpression": CRON_HOURLY}, + }, + Tags=TAGS, + ) + + def _test_model_quality_monitor_update_schedule(model_quality_monitor, sagemaker_session): # update schedule sagemaker_session.describe_monitoring_schedule = MagicMock() @@ -1248,3 +1442,42 @@ def _test_model_quality_monitor_delete_schedule(model_quality_monitor, sagemaker sagemaker_session.sagemaker_client.delete_model_quality_job_definition.assert_called_once_with( JobDefinitionName=job_definition_name ) + + +def test_batch_transform_and_endpoint_input_simultaneous_failure( + data_quality_monitor, + sagemaker_session, + constraints=None, + statistics=None, + baseline_job_name=None, + batch_transform_input=BatchTransformInput( + data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI, + destination=SCHEDULE_DESTINATION, + dataset_format=MonitoringDatasetFormat.csv(header=False), + ), + endpoint_input=EndpointInput( + endpoint_name=ENDPOINT_NAME, + destination=ENDPOINT_INPUT_LOCAL_PATH, + start_time_offset=START_TIME_OFFSET, + end_time_offset=END_TIME_OFFSET, + features_attribute=FEATURES_ATTRIBUTE, + inference_attribute=INFERENCE_ATTRIBUTE, + probability_attribute=PROBABILITY_ATTRIBUTE, + probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE, + ), +): + try: + # for batch transform input + data_quality_monitor.create_monitoring_schedule( + batch_transform_input=batch_transform_input, + record_preprocessor_script=PREPROCESSOR_URI, + post_analytics_processor_script=POSTPROCESSOR_URI, + output_s3_uri=OUTPUT_S3_URI, + constraints=constraints, + statistics=statistics, + monitor_schedule_name=SCHEDULE_NAME, + schedule_cron_expression=CRON_HOURLY, + endpoint_input=endpoint_input, + ) + except Exception as e: + assert "Need to have either batch_transform_input or endpoint_input" in str(e) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 78df274b71..a74348e1e7 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -32,6 +32,7 @@ NOTEBOOK_METADATA_FILE, ) from sagemaker.tuner import WarmStartConfig, WarmStartTypes +from sagemaker.inputs import BatchDataCaptureConfig STATIC_HPs = {"feature_dim": "784"} @@ -1373,6 +1374,7 @@ def test_transform_pack_to_request(sagemaker_session): model_client_config=None, tags=None, data_processing=data_processing, + batch_data_capture_config=None, ) _, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0] @@ -1385,6 +1387,12 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): max_payload = 0 env = {"FOO": "BAR"} + batch_data_capture_config = BatchDataCaptureConfig( + destination_s3_uri="test_uri", + kms_key_id="", + generate_inference_id=False, + ) + sagemaker_session.transform( job_name=JOB_NAME, model_name="my-model", @@ -1399,6 +1407,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): model_client_config=MODEL_CLIENT_CONFIG, tags=TAGS, data_processing=None, + batch_data_capture_config=batch_data_capture_config, ) _, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0] @@ -1409,6 +1418,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): assert actual_args["Tags"] == TAGS assert actual_args["ExperimentConfig"] == EXPERIMENT_CONFIG assert actual_args["ModelClientConfig"] == MODEL_CLIENT_CONFIG + assert actual_args["DataCaptureConfig"] == batch_data_capture_config._to_request_dict() @patch("sys.stdout", new_callable=io.BytesIO if six.PY2 else io.StringIO) diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 5e9b1009d6..3ddb23bb0c 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -16,6 +16,7 @@ from mock import MagicMock, Mock, patch from sagemaker.transformer import _TransformJob, Transformer +from sagemaker.inputs import BatchDataCaptureConfig from tests.integ import test_local_mode MODEL_NAME = "model" @@ -168,6 +169,9 @@ def test_transform_with_all_params(start_new_job, transformer): "TrialComponentDisplayName": "tc", } model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2} + batch_data_capture_config = BatchDataCaptureConfig( + destination_s3_uri=OUTPUT_PATH, kms_key_id=KMS_KEY_ID, generate_inference_id=False + ) transformer.transform( DATA, @@ -181,6 +185,7 @@ def test_transform_with_all_params(start_new_job, transformer): join_source=join_source, experiment_config=experiment_config, model_client_config=model_client_config, + batch_data_capture_config=batch_data_capture_config, ) assert transformer._current_job_name == JOB_NAME @@ -197,6 +202,7 @@ def test_transform_with_all_params(start_new_job, transformer): join_source, experiment_config, model_client_config, + batch_data_capture_config, ) @@ -433,6 +439,10 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session): join_source = "Input" model_client_config = {"InvocationsTimeoutInSeconds": 60, "InvocationsMaxRetries": 2} + batch_data_capture_config = BatchDataCaptureConfig( + destination_s3_uri=OUTPUT_PATH, kms_key_id=KMS_KEY_ID, generate_inference_id=False + ) + job = _TransformJob.start_new( transformer=transformer, data=DATA, @@ -445,6 +455,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session): join_source=join_source, experiment_config={"ExperimentName": "exp"}, model_client_config=model_client_config, + batch_data_capture_config=batch_data_capture_config, ) assert job.sagemaker_session == sagemaker_session @@ -469,6 +480,7 @@ def test_start_new(prepare_data_processing, load_config, sagemaker_session): model_client_config=model_client_config, tags=tags, data_processing=prepare_data_processing.return_value, + batch_data_capture_config=batch_data_capture_config, )