Skip to content

feature: added support for batch transform with model monitoring #3418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/sagemaker/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ class TransformInput(object):
output_filter: str = attr.ib(default=None)
join_source: str = attr.ib(default=None)
model_client_config: dict = attr.ib(default=None)
batch_data_capture_config: dict = attr.ib(default=None)


class FileSystemInput(object):
Expand Down Expand Up @@ -232,3 +233,45 @@ def __init__(

if content_type:
self.config["ContentType"] = content_type


class BatchDataCaptureConfig(object):
    """Configuration object passed in when creating a batch transform job.

    Specifies data-capture settings for a batch transform job, for use with
    Amazon SageMaker Model Monitoring.
    """

    def __init__(
        self,
        destination_s3_uri: str,
        kms_key_id: str = None,
        generate_inference_id: bool = None,
    ):
        """Create a new BatchDataCaptureConfig.

        Args:
            destination_s3_uri (str): S3 location to store the captured data.
            kms_key_id (str): The KMS key to use when writing to S3. KmsKeyId
                can be the ID, ARN, alias, or alias ARN of a KMS key. The
                KmsKeyId is applied to all outputs. (default: None)
            generate_inference_id (bool): Whether to generate an inference id.
                (default: None)
        """
        self.destination_s3_uri = destination_s3_uri
        self.kms_key_id = kms_key_id
        self.generate_inference_id = generate_inference_id

    def _to_request_dict(self):
        """Build the DataCaptureConfig request dictionary from this object's fields.

        Optional fields are included only when they were explicitly set
        (i.e. are not None), so a False ``generate_inference_id`` is still sent.
        """
        request = {"DestinationS3Uri": self.destination_s3_uri}
        optional_fields = {
            "KmsKeyId": self.kms_key_id,
            "GenerateInferenceId": self.generate_inference_id,
        }
        request.update({key: value for key, value in optional_fields.items() if value is not None})
        return request
2 changes: 2 additions & 0 deletions src/sagemaker/model_monitor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sagemaker.model_monitor.model_monitoring import BaseliningJob # noqa: F401
from sagemaker.model_monitor.model_monitoring import MonitoringExecution # noqa: F401
from sagemaker.model_monitor.model_monitoring import EndpointInput # noqa: F401
from sagemaker.model_monitor.model_monitoring import BatchTransformInput # noqa: F401
from sagemaker.model_monitor.model_monitoring import MonitoringOutput # noqa: F401
from sagemaker.model_monitor.model_monitoring import ModelQualityMonitor # noqa: F401

Expand All @@ -42,5 +43,6 @@

from sagemaker.model_monitor.data_capture_config import DataCaptureConfig # noqa: F401
from sagemaker.model_monitor.dataset_format import DatasetFormat # noqa: F401
from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat # noqa: F401

from sagemaker.network import NetworkConfig # noqa: F401
104 changes: 93 additions & 11 deletions src/sagemaker/model_monitor/clarify_model_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def _build_create_job_definition_request(
env=None,
tags=None,
network_config=None,
batch_transform_input=None,
):
"""Build the request for job definition creation API

Expand Down Expand Up @@ -270,6 +271,8 @@ def _build_create_job_definition_request(
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
object that configures network isolation, encryption of
inter-container traffic, security group IDs, and subnets.
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
the monitoring schedule on the batch transform

Returns:
dict: request parameters to create job definition.
Expand Down Expand Up @@ -366,6 +369,27 @@ def _build_create_job_definition_request(
latest_baselining_job_config.probability_threshold_attribute
)
job_input = normalized_endpoint_input._to_request_dict()
elif batch_transform_input is not None:
# backfill attributes to batch transform input
if latest_baselining_job_config is not None:
if batch_transform_input.features_attribute is None:
batch_transform_input.features_attribute = (
latest_baselining_job_config.features_attribute
)
if batch_transform_input.inference_attribute is None:
batch_transform_input.inference_attribute = (
latest_baselining_job_config.inference_attribute
)
if batch_transform_input.probability_attribute is None:
batch_transform_input.probability_attribute = (
latest_baselining_job_config.probability_attribute
)
if batch_transform_input.probability_threshold_attribute is None:
batch_transform_input.probability_threshold_attribute = (
latest_baselining_job_config.probability_threshold_attribute
)
job_input = batch_transform_input._to_request_dict()

if ground_truth_input is not None:
job_input["GroundTruthS3Input"] = dict(S3Uri=ground_truth_input)

Expand Down Expand Up @@ -500,37 +524,46 @@ def suggest_baseline(
# noinspection PyMethodOverriding
def create_monitoring_schedule(
self,
endpoint_input,
ground_truth_input,
endpoint_input=None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing the order of arguments might break existing customers' workflows.

Copy link
Contributor

@jerrypeng7773 jerrypeng7773 Oct 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but endpoint_input became optional, and a non-None ground_truth_input has to come first.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Talked offline with @mufaddal-rohawala: we default ground_truth_input to None in the signature; however, we raise a ValueError if it is None within the function.

ground_truth_input=None,
analysis_config=None,
output_s3_uri=None,
constraints=None,
monitor_schedule_name=None,
schedule_cron_expression=None,
enable_cloudwatch_metrics=True,
batch_transform_input=None,
):
"""Creates a monitoring schedule.

Args:
endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
This can either be the endpoint name or an EndpointInput.
ground_truth_input (str): S3 URI to ground truth dataset.
This can either be the endpoint name or an EndpointInput. (default: None)
ground_truth_input (str): S3 URI to ground truth dataset. (default: None)
analysis_config (str or BiasAnalysisConfig): URI to analysis_config for the bias job.
If it is None then configuration of the latest baselining job will be reused, but
if no baselining job then fail the call.
if no baselining job then fail the call. (default: None)
output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
Default: "s3://<default_session_bucket>/<job_name>/output"
Default: "s3://<default_session_bucket>/<job_name>/output" (default: None)
constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
to a constraints JSON file.
to a constraints JSON file. (default: None)
monitor_schedule_name (str): Schedule name. If not specified, the processor generates
a default job name, based on the image name and current timestamp.
(default: None)
schedule_cron_expression (str): The cron expression that dictates the frequency that
this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
expressions. Default: Daily.
expressions. Default: Daily. (default: None)
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
the baselining or monitoring jobs.
the baselining or monitoring jobs. (default: True)
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
the monitoring schedule on the batch transform (default: None)
"""
# we default ground_truth_input to None in the function signature
# but verify they are giving here for positional argument
# backward compatibility reason.
if not ground_truth_input:
raise ValueError("ground_truth_input can not be None.")
if self.job_definition_name is not None or self.monitoring_schedule_name is not None:
message = (
"It seems that this object was already used to create an Amazon Model "
Expand All @@ -540,6 +573,15 @@ def create_monitoring_schedule(
_LOGGER.error(message)
raise ValueError(message)

if (batch_transform_input is not None) ^ (endpoint_input is None):
message = (
"Need to have either batch_transform_input or endpoint_input to create an "
"Amazon Model Monitoring Schedule. "
"Please provide only one of the above required inputs"
)
_LOGGER.error(message)
raise ValueError(message)

# create job definition
monitor_schedule_name = self._generate_monitoring_schedule_name(
schedule_name=monitor_schedule_name
Expand Down Expand Up @@ -569,6 +611,7 @@ def create_monitoring_schedule(
env=self.env,
tags=self.tags,
network_config=self.network_config,
batch_transform_input=batch_transform_input,
)
self.sagemaker_session.sagemaker_client.create_model_bias_job_definition(**request_dict)

Expand Down Expand Up @@ -612,6 +655,7 @@ def update_monitoring_schedule(
max_runtime_in_seconds=None,
env=None,
network_config=None,
batch_transform_input=None,
):
"""Updates the existing monitoring schedule.

Expand Down Expand Up @@ -651,6 +695,8 @@ def update_monitoring_schedule(
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
object that configures network isolation, encryption of
inter-container traffic, security group IDs, and subnets.
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
the monitoring schedule on the batch transform
"""
valid_args = {
arg: value for arg, value in locals().items() if arg != "self" and value is not None
Expand All @@ -660,6 +706,15 @@ def update_monitoring_schedule(
if len(valid_args) <= 0:
return

if batch_transform_input is not None and endpoint_input is not None:
message = (
"Need to have either batch_transform_input or endpoint_input to create an "
"Amazon Model Monitoring Schedule. "
"Please provide only one of the above required inputs"
)
_LOGGER.error(message)
raise ValueError(message)

# Only need to update schedule expression
if len(valid_args) == 1 and schedule_cron_expression is not None:
self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
Expand Down Expand Up @@ -691,6 +746,7 @@ def update_monitoring_schedule(
env=env,
tags=self.tags,
network_config=network_config,
batch_transform_input=batch_transform_input,
)
self.sagemaker_session.sagemaker_client.create_model_bias_job_definition(**request_dict)
try:
Expand Down Expand Up @@ -895,19 +951,20 @@ def suggest_baseline(
# noinspection PyMethodOverriding
def create_monitoring_schedule(
self,
endpoint_input,
endpoint_input=None,
analysis_config=None,
output_s3_uri=None,
constraints=None,
monitor_schedule_name=None,
schedule_cron_expression=None,
enable_cloudwatch_metrics=True,
batch_transform_input=None,
):
"""Creates a monitoring schedule.

Args:
endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
This can either be the endpoint name or an EndpointInput.
This can either be the endpoint name or an EndpointInput. (default: None)
analysis_config (str or ExplainabilityAnalysisConfig): URI to the analysis_config for
the explainability job. If it is None then configuration of the latest baselining
job will be reused, but if no baselining job then fail the call.
Expand All @@ -923,6 +980,8 @@ def create_monitoring_schedule(
expressions. Default: Daily.
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
the baselining or monitoring jobs.
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
run the monitoring schedule on the batch transform
"""
if self.job_definition_name is not None or self.monitoring_schedule_name is not None:
message = (
Expand All @@ -933,6 +992,15 @@ def create_monitoring_schedule(
_LOGGER.error(message)
raise ValueError(message)

if (batch_transform_input is not None) ^ (endpoint_input is None):
message = (
"Need to have either batch_transform_input or endpoint_input to create an "
"Amazon Model Monitoring Schedule."
"Please provide only one of the above required inputs"
)
_LOGGER.error(message)
raise ValueError(message)

# create job definition
monitor_schedule_name = self._generate_monitoring_schedule_name(
schedule_name=monitor_schedule_name
Expand Down Expand Up @@ -961,6 +1029,7 @@ def create_monitoring_schedule(
env=self.env,
tags=self.tags,
network_config=self.network_config,
batch_transform_input=batch_transform_input,
)
self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition(
**request_dict
Expand Down Expand Up @@ -1005,6 +1074,7 @@ def update_monitoring_schedule(
max_runtime_in_seconds=None,
env=None,
network_config=None,
batch_transform_input=None,
):
"""Updates the existing monitoring schedule.

Expand Down Expand Up @@ -1043,6 +1113,8 @@ def update_monitoring_schedule(
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
object that configures network isolation, encryption of
inter-container traffic, security group IDs, and subnets.
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
run the monitoring schedule on the batch transform
"""
valid_args = {
arg: value for arg, value in locals().items() if arg != "self" and value is not None
Expand All @@ -1052,6 +1124,15 @@ def update_monitoring_schedule(
if len(valid_args) <= 0:
raise ValueError("Nothing to update.")

if batch_transform_input is not None and endpoint_input is not None:
message = (
"Need to have either batch_transform_input or endpoint_input to create an "
"Amazon Model Monitoring Schedule. "
"Please provide only one of the above required inputs"
)
_LOGGER.error(message)
raise ValueError(message)

# Only need to update schedule expression
if len(valid_args) == 1 and schedule_cron_expression is not None:
self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
Expand Down Expand Up @@ -1084,6 +1165,7 @@ def update_monitoring_schedule(
env=env,
tags=self.tags,
network_config=network_config,
batch_transform_input=batch_transform_input,
)
self.sagemaker_session.sagemaker_client.create_model_explainability_job_definition(
**request_dict
Expand Down
43 changes: 43 additions & 0 deletions src/sagemaker/model_monitor/dataset_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,46 @@ def sagemaker_capture_json():

"""
return {"sagemakerCaptureJson": {}}


class MonitoringDatasetFormat(object):
    """Represents a dataset format that is used when calling a DefaultModelMonitor."""

    @staticmethod
    def csv(header=True):
        """Returns a CSV dataset-format specification for a DefaultModelMonitor.

        Args:
            header (bool): Whether the csv dataset to baseline and monitor has a header.
                Default: True.

        Returns:
            dict: DatasetFormat dictionary to be used by DefaultModelMonitor.
        """
        return dict(Csv=dict(Header=header))

    @staticmethod
    def json(lines=True):
        """Returns a JSON dataset-format specification for a DefaultModelMonitor.

        Args:
            lines (bool): Whether the file should be read as a json object per line.
                Default: True.

        Returns:
            dict: DatasetFormat dictionary to be used by DefaultModelMonitor.
        """
        return dict(Json=dict(Line=lines))

    @staticmethod
    def parquet():
        """Returns a Parquet dataset-format specification for a DefaultModelMonitor.

        Returns:
            dict: DatasetFormat dictionary to be used by DefaultModelMonitor.
        """
        return dict(Parquet=dict())
Loading