fix: updating batch transform job in monitoring schedule #3767

Merged
42 changes: 42 additions & 0 deletions src/sagemaker/model_monitor/model_monitoring.py
@@ -433,6 +433,7 @@ def update_monitoring_schedule(
network_config=None,
role=None,
image_uri=None,
batch_transform_input=None,
):
"""Updates the existing monitoring schedule.

@@ -475,14 +476,29 @@ def update_monitoring_schedule(
role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
image_uri (str): The uri of the image to use for the jobs started by
the Monitor.
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Input to
run the monitoring schedule against data captured by a batch transform job
(default: None)

"""
monitoring_inputs = None

if (batch_transform_input is not None) and (endpoint_input is not None):
message = (
"Cannot update both batch_transform_input and endpoint_input to update an "
"Amazon Model Monitoring Schedule. "
"Please provide atmost one of the above required inputs"
)
_LOGGER.error(message)
raise ValueError(message)

if endpoint_input is not None:
monitoring_inputs = [
self._normalize_endpoint_input(endpoint_input=endpoint_input)._to_request_dict()
]

elif batch_transform_input is not None:
monitoring_inputs = [batch_transform_input._to_request_dict()]

monitoring_output_config = None
if output is not None:
normalized_monitoring_output = self._normalize_monitoring_output_fields(output=output)
@@ -1895,6 +1911,7 @@ def update_monitoring_schedule(
network_config=None,
enable_cloudwatch_metrics=None,
role=None,
batch_transform_input=None,
):
"""Updates the existing monitoring schedule.

@@ -1936,8 +1953,20 @@ def update_monitoring_schedule(
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
the baselining or monitoring jobs.
role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Input to
run the monitoring schedule against data captured by a batch transform job
(default: None)

"""

if (batch_transform_input is not None) and (endpoint_input is not None):
message = (
"Cannot update both batch_transform_input and endpoint_input to update an "
"Amazon Model Monitoring Schedule. "
"Please provide atmost one of the above required inputs"
)
_LOGGER.error(message)
raise ValueError(message)

# check whether this schedule is in v2 format and, if so, update it via the v2 path
if self.job_definition_name is not None:
self._update_data_quality_monitoring_schedule(
@@ -1958,13 +1987,17 @@
network_config=network_config,
enable_cloudwatch_metrics=enable_cloudwatch_metrics,
role=role,
batch_transform_input=batch_transform_input,
)
return

monitoring_inputs = None
if endpoint_input is not None:
monitoring_inputs = [self._normalize_endpoint_input(endpoint_input)._to_request_dict()]

elif batch_transform_input is not None:
monitoring_inputs = [batch_transform_input._to_request_dict()]

record_preprocessor_script_s3_uri = None
if record_preprocessor_script is not None:
record_preprocessor_script_s3_uri = self._s3_uri_from_local_path(
@@ -3022,6 +3055,15 @@ def update_monitoring_schedule(
self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
return

if (batch_transform_input is not None) and (endpoint_input is not None):
message = (
"Cannot update both batch_transform_input and endpoint_input to update an "
"Amazon Model Monitoring Schedule. "
"Please provide atmost one of the above required inputs"
)
_LOGGER.error(message)
raise ValueError(message)

# Need to update schedule with a new job definition
job_desc = self.sagemaker_session.sagemaker_client.describe_model_quality_job_definition(
JobDefinitionName=self.job_definition_name
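
A minimal usage sketch of the new parameter, assuming the monitor already has a schedule attached (e.g. one created earlier with create_monitoring_schedule); the role ARN, bucket, and paths are illustrative placeholders, not values from this PR:

from sagemaker.model_monitor import (
    BatchTransformInput,
    DefaultModelMonitor,
    MonitoringDatasetFormat,
)

# Placeholder role and instance settings; an existing monitoring schedule is
# assumed to be attached to this monitor object.
my_monitor = DefaultModelMonitor(
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# Placeholder S3 URI pointing at data captured by a batch transform job.
batch_transform_input = BatchTransformInput(
    data_captured_destination_s3_uri="s3://my-bucket/transform-data-capture",
    destination="/opt/ml/processing/input",
    dataset_format=MonitoringDatasetFormat.csv(header=False),
)

# Repoint the existing schedule at the batch transform capture; passing
# endpoint_input in the same call now raises ValueError.
my_monitor.update_monitoring_schedule(batch_transform_input=batch_transform_input)
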
194 changes: 192 additions & 2 deletions tests/integ/test_model_monitor.py
@@ -281,7 +281,6 @@ def updated_output_kms_key(sagemaker_session):
tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS,
reason="ModelMonitoring is not yet supported in this region.",
)
@pytest.mark.release
def test_default_monitoring_batch_transform_schedule_name(
sagemaker_session, output_kms_key, volume_kms_key
):
@@ -359,7 +358,6 @@ def test_default_monitoring_batch_transform_schedule_name(
tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS,
reason="ModelMonitoring is not yet supported in this region.",
)
@pytest.mark.release
def test_default_monitor_suggest_baseline_and_create_monitoring_schedule_with_customizations(
sagemaker_session, output_kms_key, volume_kms_key, predictor
):
@@ -1852,3 +1850,195 @@ def _verify_default_monitoring_schedule_with_batch_transform(
)
else:
assert network_config is None


def test_default_update_monitoring_batch_transform(
sagemaker_session, output_kms_key, volume_kms_key
):
my_default_monitor = DefaultModelMonitor(
role=ROLE,
instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
volume_size_in_gb=VOLUME_SIZE_IN_GB,
volume_kms_key=volume_kms_key,
output_kms_key=output_kms_key,
max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
sagemaker_session=sagemaker_session,
env=ENVIRONMENT,
tags=TAGS,
network_config=NETWORK_CONFIG,
)

output_s3_uri = os.path.join(
"s3://",
sagemaker_session.default_bucket(),
"integ-test-monitoring-output-bucket",
str(uuid.uuid4()),
)

data_captured_destination_s3_uri = os.path.join(
"s3://",
sagemaker_session.default_bucket(),
"sagemaker-serving-batch-transform",
str(uuid.uuid4()),
)

batch_transform_input = BatchTransformInput(
data_captured_destination_s3_uri=data_captured_destination_s3_uri,
destination="/opt/ml/processing/output",
dataset_format=MonitoringDatasetFormat.csv(header=False),
)

statistics = Statistics.from_file_path(
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
sagemaker_session=sagemaker_session,
)

constraints = Constraints.from_file_path(
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
sagemaker_session=sagemaker_session,
)

my_default_monitor.create_monitoring_schedule(
batch_transform_input=batch_transform_input,
output_s3_uri=output_s3_uri,
statistics=statistics,
constraints=constraints,
schedule_cron_expression=HOURLY_CRON_EXPRESSION,
enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
)

_wait_for_schedule_changes_to_apply(monitor=my_default_monitor)

data_captured_destination_s3_uri = os.path.join(
"s3://",
sagemaker_session.default_bucket(),
"sagemaker-tensorflow-serving-batch-transform",
str(uuid.uuid4()),
)

batch_transform_input = BatchTransformInput(
data_captured_destination_s3_uri=data_captured_destination_s3_uri,
destination="/opt/ml/processing/output",
dataset_format=MonitoringDatasetFormat.csv(header=False),
)

my_default_monitor.update_monitoring_schedule(
batch_transform_input=batch_transform_input,
)

_wait_for_schedule_changes_to_apply(monitor=my_default_monitor)

schedule_description = my_default_monitor.describe_schedule()

_verify_default_monitoring_schedule_with_batch_transform(
sagemaker_session=sagemaker_session,
schedule_description=schedule_description,
cron_expression=HOURLY_CRON_EXPRESSION,
statistics=statistics,
constraints=constraints,
output_kms_key=output_kms_key,
volume_kms_key=volume_kms_key,
network_config=NETWORK_CONFIG,
)

my_default_monitor.stop_monitoring_schedule()
my_default_monitor.delete_monitoring_schedule()


def test_byoc_monitoring_schedule_name_update_batch(
sagemaker_session, output_kms_key, volume_kms_key
):
byoc_env = ENVIRONMENT.copy()
byoc_env["dataset_format"] = json.dumps(DatasetFormat.csv(header=False))
byoc_env["dataset_source"] = "/opt/ml/processing/input/baseline_dataset_input"
byoc_env["output_path"] = os.path.join("/opt/ml/processing/output")
byoc_env["publish_cloudwatch_metrics"] = "Disabled"

my_byoc_monitor = ModelMonitor(
role=ROLE,
image_uri=DefaultModelMonitor._get_default_image_uri(
sagemaker_session.boto_session.region_name
),
instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
volume_size_in_gb=VOLUME_SIZE_IN_GB,
volume_kms_key=volume_kms_key,
output_kms_key=output_kms_key,
max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
sagemaker_session=sagemaker_session,
env=byoc_env,
tags=TAGS,
network_config=NETWORK_CONFIG,
)

output_s3_uri = os.path.join(
"s3://",
sagemaker_session.default_bucket(),
"integ-test-monitoring-output-bucket",
str(uuid.uuid4()),
)

statistics = Statistics.from_file_path(
statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
sagemaker_session=sagemaker_session,
)

constraints = Constraints.from_file_path(
constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
sagemaker_session=sagemaker_session,
)

data_captured_destination_s3_uri = os.path.join(
"s3://",
sagemaker_session.default_bucket(),
"sagemaker-serving-batch-transform",
str(uuid.uuid4()),
)

batch_transform_input = BatchTransformInput(
data_captured_destination_s3_uri=data_captured_destination_s3_uri,
destination="/opt/ml/processing/output",
dataset_format=MonitoringDatasetFormat.csv(header=False),
)

my_byoc_monitor.create_monitoring_schedule(
batch_transform_input=batch_transform_input,
output=MonitoringOutput(source="/opt/ml/processing/output", destination=output_s3_uri),
statistics=statistics,
constraints=constraints,
schedule_cron_expression=HOURLY_CRON_EXPRESSION,
)

_wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)

data_captured_destination_s3_uri = os.path.join(
"s3://",
sagemaker_session.default_bucket(),
"sagemaker-tensorflow-serving-batch-transform",
str(uuid.uuid4()),
)

batch_transform_input = BatchTransformInput(
data_captured_destination_s3_uri=data_captured_destination_s3_uri,
destination="/opt/ml/processing/output",
dataset_format=MonitoringDatasetFormat.csv(header=False),
)

my_byoc_monitor.update_monitoring_schedule(
batch_transform_input=batch_transform_input,
)

_wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)

schedule_description = my_byoc_monitor.describe_schedule()

assert (
data_captured_destination_s3_uri
== schedule_description["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringInputs"
][0]["BatchTransformInput"]["DataCapturedDestinationS3Uri"]
)

my_byoc_monitor.stop_monitoring_schedule()
my_byoc_monitor.delete_monitoring_schedule()
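
Both new tests gate their assertions on _wait_for_schedule_changes_to_apply, a helper defined elsewhere in this test module and not shown in the diff. A hypothetical sketch of such a poller, assuming it simply waits for the schedule to leave the transitional "Pending" status (the real implementation may differ):

import time

def _wait_for_schedule_changes_to_apply(monitor):
    # Bounded wait: roughly ten minutes at one describe call every 5 seconds.
    for _ in range(120):
        description = monitor.describe_schedule()
        if description["MonitoringScheduleStatus"] != "Pending":
            return description
        time.sleep(5)
    raise TimeoutError("monitoring schedule stayed in Pending status too long")
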
44 changes: 44 additions & 0 deletions tests/unit/sagemaker/monitor/test_model_monitoring.py
@@ -1836,3 +1836,47 @@ def test_model_monitor_with_arguments(
role_arn=ROLE,
tags=TAGS,
)


def test_update_model_monitor_error_with_endpoint_and_batch(
    model_monitor_arguments,
    data_quality_monitor,
):
    endpoint_input = EndpointInput(
        endpoint_name=ENDPOINT_NAME,
        destination=ENDPOINT_INPUT_LOCAL_PATH,
        start_time_offset=START_TIME_OFFSET,
        end_time_offset=END_TIME_OFFSET,
        features_attribute=FEATURES_ATTRIBUTE,
        inference_attribute=INFERENCE_ATTRIBUTE,
        probability_attribute=PROBABILITY_ATTRIBUTE,
        probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE,
    )

    batch_transform_input = BatchTransformInput(
        data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI,
        destination=SCHEDULE_DESTINATION,
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )

    with pytest.raises(ValueError) as error:
        model_monitor_arguments.update_monitoring_schedule(
            schedule_cron_expression=CRON_HOURLY,
            endpoint_input=endpoint_input,
            arguments=ARGUMENTS,
            output=MonitoringOutput(source="/opt/ml/processing/output", destination=OUTPUT_S3_URI),
            batch_transform_input=batch_transform_input,
        )
    assert "Cannot update both batch_transform_input and endpoint_input to update an" in str(
        error.value
    )

    with pytest.raises(ValueError) as error:
        data_quality_monitor.update_monitoring_schedule(
            schedule_cron_expression=CRON_HOURLY,
            endpoint_input=endpoint_input,
            batch_transform_input=batch_transform_input,
        )
    assert "Cannot update both batch_transform_input and endpoint_input to update an" in str(
        error.value
    )
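
The mutual-exclusion guard exercised by this test is repeated verbatim in three update paths in model_monitoring.py. A possible follow-up, not part of this PR, would be to factor it into a single module-level helper, sketched here (_validate_exclusive_monitoring_inputs is a hypothetical name):

def _validate_exclusive_monitoring_inputs(endpoint_input, batch_transform_input):
    # Reject update calls that supply both an endpoint input and a batch
    # transform input; mirrors the inline checks added by this PR.
    if (batch_transform_input is not None) and (endpoint_input is not None):
        message = (
            "Cannot update both batch_transform_input and endpoint_input to update an "
            "Amazon Model Monitoring Schedule. "
            "Please provide at most one of the above inputs"
        )
        _LOGGER.error(message)
        raise ValueError(message)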