
Commit 1d3eec1

Author: Keshav Chandak
Fix: updating batch transform job in monitoring schedule

1 parent: c4c4e83

File tree: 3 files changed (+279, -1 lines)
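
In short, the commit threads a new batch_transform_input parameter through the update_monitoring_schedule methods, mirroring the support create_monitoring_schedule already has, and rejects calls that pass both an endpoint input and a batch transform input. A minimal usage sketch of what this enables — the schedule name and S3 URI below are placeholders of our own choosing, and attach assumes a schedule created earlier:

from sagemaker.model_monitor import BatchTransformInput, DefaultModelMonitor
from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat

# Re-attach to an existing schedule (placeholder name).
monitor = DefaultModelMonitor.attach(monitor_schedule_name="my-schedule")

batch_transform_input = BatchTransformInput(
    data_captured_destination_s3_uri="s3://my-bucket/data-capture",  # placeholder
    destination="/opt/ml/processing/output",
    dataset_format=MonitoringDatasetFormat.csv(header=False),
)

# New in this commit: update an existing schedule to monitor batch transform data.
monitor.update_monitoring_schedule(batch_transform_input=batch_transform_input)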

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 42 additions & 0 deletions

@@ -433,6 +433,7 @@ def update_monitoring_schedule(
         network_config=None,
         role=None,
         image_uri=None,
+        batch_transform_input=None,
     ):
         """Updates the existing monitoring schedule.
 
@@ -475,14 +476,29 @@ def update_monitoring_schedule(
             role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
             image_uri (str): The uri of the image to use for the jobs started by
                 the Monitor.
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
+                run the monitoring schedule on the batch transform (default: None)
 
         """
         monitoring_inputs = None
+
+        if (batch_transform_input is not None) and (endpoint_input is not None):
+            message = (
+                "Cannot update both batch_transform_input and endpoint_input to update an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide at most one of the above inputs"
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         if endpoint_input is not None:
             monitoring_inputs = [
                 self._normalize_endpoint_input(endpoint_input=endpoint_input)._to_request_dict()
             ]
 
+        elif batch_transform_input is not None:
+            monitoring_inputs = [batch_transform_input._to_request_dict()]
+
         monitoring_output_config = None
         if output is not None:
             normalized_monitoring_output = self._normalize_monitoring_output_fields(output=output)
@@ -1895,6 +1911,7 @@ def update_monitoring_schedule(
         network_config=None,
         enable_cloudwatch_metrics=None,
         role=None,
+        batch_transform_input=None,
     ):
         """Updates the existing monitoring schedule.
 
@@ -1936,8 +1953,20 @@ def update_monitoring_schedule(
             enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
                 the baselining or monitoring jobs.
             role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
+                run the monitoring schedule on the batch transform (default: None)
 
         """
+
+        if (batch_transform_input is not None) and (endpoint_input is not None):
+            message = (
+                "Cannot update both batch_transform_input and endpoint_input to update an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide at most one of the above inputs"
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         # check if this schedule is in v2 format and update as per v2 format if it is
         if self.job_definition_name is not None:
             self._update_data_quality_monitoring_schedule(
@@ -1958,13 +1987,17 @@ def update_monitoring_schedule(
                 network_config=network_config,
                 enable_cloudwatch_metrics=enable_cloudwatch_metrics,
                 role=role,
+                batch_transform_input=batch_transform_input,
             )
             return
 
         monitoring_inputs = None
         if endpoint_input is not None:
             monitoring_inputs = [self._normalize_endpoint_input(endpoint_input)._to_request_dict()]
 
+        elif batch_transform_input is not None:
+            monitoring_inputs = [batch_transform_input._to_request_dict()]
+
         record_preprocessor_script_s3_uri = None
         if record_preprocessor_script is not None:
             record_preprocessor_script_s3_uri = self._s3_uri_from_local_path(
@@ -3022,6 +3055,15 @@ def update_monitoring_schedule(
             self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
             return
 
+        if (batch_transform_input is not None) and (endpoint_input is not None):
+            message = (
+                "Cannot update both batch_transform_input and endpoint_input to update an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide at most one of the above inputs"
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         # Need to update schedule with a new job definition
         job_desc = self.sagemaker_session.sagemaker_client.describe_model_quality_job_definition(
             JobDefinitionName=self.job_definition_name
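
The same guard appears in each of the three update_monitoring_schedule variants above. A small sketch of the failure mode it closes off, reusing monitor and batch_transform_input from the sketch near the top (the endpoint name is a placeholder):

from sagemaker.model_monitor import EndpointInput

# The guard rejects ambiguous updates up front instead of letting the
# if/elif ordering silently prefer the endpoint input.
try:
    monitor.update_monitoring_schedule(
        endpoint_input=EndpointInput(
            endpoint_name="my-endpoint",  # placeholder
            destination="/opt/ml/processing/input",
        ),
        batch_transform_input=batch_transform_input,
    )
except ValueError as err:
    print(err)  # Cannot update both batch_transform_input and endpoint_input ...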

tests/integ/test_model_monitor.py

Lines changed: 193 additions & 1 deletion

@@ -281,7 +281,6 @@ def updated_output_kms_key(sagemaker_session):
     tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS,
     reason="ModelMonitoring is not yet supported in this region.",
 )
-@pytest.mark.release
 def test_default_monitoring_batch_transform_schedule_name(
     sagemaker_session, output_kms_key, volume_kms_key
 ):
@@ -1852,3 +1851,196 @@ def _verify_default_monitoring_schedule_with_batch_transform(
         )
     else:
         assert network_config is None
+
+
+def test_default_update_monitoring_batch_transform(
+    sagemaker_session, output_kms_key, volume_kms_key
+):
+    my_default_monitor = DefaultModelMonitor(
+        role=ROLE,
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE,
+        volume_size_in_gb=VOLUME_SIZE_IN_GB,
+        volume_kms_key=volume_kms_key,
+        output_kms_key=output_kms_key,
+        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
+        sagemaker_session=sagemaker_session,
+        env=ENVIRONMENT,
+        tags=TAGS,
+        network_config=NETWORK_CONFIG,
+    )
+
+    output_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "integ-test-monitoring-output-bucket",
+        str(uuid.uuid4()),
+    )
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    statistics = Statistics.from_file_path(
+        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    constraints = Constraints.from_file_path(
+        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    my_default_monitor.create_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+        output_s3_uri=output_s3_uri,
+        statistics=statistics,
+        constraints=constraints,
+        schedule_cron_expression=HOURLY_CRON_EXPRESSION,
+        enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_default_monitor)
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-tensorflow-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    my_default_monitor.update_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_default_monitor)
+
+    schedule_description = my_default_monitor.describe_schedule()
+
+    _verify_default_monitoring_schedule_with_batch_transform(
+        sagemaker_session=sagemaker_session,
+        schedule_description=schedule_description,
+        cron_expression=HOURLY_CRON_EXPRESSION,
+        statistics=statistics,
+        constraints=constraints,
+        output_kms_key=output_kms_key,
+        volume_kms_key=volume_kms_key,
+        network_config=NETWORK_CONFIG,
+    )
+
+    my_default_monitor.stop_monitoring_schedule()
+    my_default_monitor.delete_monitoring_schedule()
+
+
+@pytest.mark.release
+def test_byoc_monitoring_schedule_name_update_batch(
+    sagemaker_session, output_kms_key, volume_kms_key
+):
+    byoc_env = ENVIRONMENT.copy()
+    byoc_env["dataset_format"] = json.dumps(DatasetFormat.csv(header=False))
+    byoc_env["dataset_source"] = "/opt/ml/processing/input/baseline_dataset_input"
+    byoc_env["output_path"] = os.path.join("/opt/ml/processing/output")
+    byoc_env["publish_cloudwatch_metrics"] = "Disabled"
+
+    my_byoc_monitor = ModelMonitor(
+        role=ROLE,
+        image_uri=DefaultModelMonitor._get_default_image_uri(
+            sagemaker_session.boto_session.region_name
+        ),
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE,
+        volume_size_in_gb=VOLUME_SIZE_IN_GB,
+        volume_kms_key=volume_kms_key,
+        output_kms_key=output_kms_key,
+        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
+        sagemaker_session=sagemaker_session,
+        env=byoc_env,
+        tags=TAGS,
+        network_config=NETWORK_CONFIG,
+    )
+
+    output_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "integ-test-monitoring-output-bucket",
+        str(uuid.uuid4()),
+    )
+
+    statistics = Statistics.from_file_path(
+        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    constraints = Constraints.from_file_path(
+        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    my_byoc_monitor.create_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+        output=MonitoringOutput(source="/opt/ml/processing/output", destination=output_s3_uri),
+        statistics=statistics,
+        constraints=constraints,
+        schedule_cron_expression=HOURLY_CRON_EXPRESSION,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-tensorflow-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    my_byoc_monitor.update_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)
+
+    schedule_description = my_byoc_monitor.describe_schedule()
+
+    assert (
+        data_captured_destination_s3_uri
+        == schedule_description["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
+            "MonitoringInputs"
+        ][0]["BatchTransformInput"]["DataCapturedDestinationS3Uri"]
+    )
+
+    my_byoc_monitor.stop_monitoring_schedule()
+    my_byoc_monitor.delete_monitoring_schedule()
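
The final assertion in the BYOC test above walks the DescribeMonitoringSchedule response. For orientation, a trimmed sketch of the nesting it relies on — every other key omitted and the URI a placeholder:

schedule_description = {
    "MonitoringScheduleConfig": {
        "MonitoringJobDefinition": {
            "MonitoringInputs": [
                {
                    "BatchTransformInput": {
                        "DataCapturedDestinationS3Uri": "s3://<bucket>/sagemaker-tensorflow-serving-batch-transform/<uuid>",
                    },
                },
            ],
        },
    },
}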

tests/unit/sagemaker/monitor/test_model_monitoring.py

Lines changed: 44 additions & 0 deletions

@@ -1836,3 +1836,47 @@ def test_model_monitor_with_arguments(
         role_arn=ROLE,
         tags=TAGS,
     )
+
+
+def test_update_model_monitor_error_with_endpoint_and_batch(
+    model_monitor_arguments,
+    data_quality_monitor,
+    endpoint_input=EndpointInput(
+        endpoint_name=ENDPOINT_NAME,
+        destination=ENDPOINT_INPUT_LOCAL_PATH,
+        start_time_offset=START_TIME_OFFSET,
+        end_time_offset=END_TIME_OFFSET,
+        features_attribute=FEATURES_ATTRIBUTE,
+        inference_attribute=INFERENCE_ATTRIBUTE,
+        probability_attribute=PROBABILITY_ATTRIBUTE,
+        probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE,
+    ),
+    batch_transform_input=BatchTransformInput(
+        data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI,
+        destination=SCHEDULE_DESTINATION,
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    ),
+):
+    with pytest.raises(ValueError) as error:
+        model_monitor_arguments.update_monitoring_schedule(
+            schedule_cron_expression=CRON_HOURLY,
+            endpoint_input=endpoint_input,
+            arguments=ARGUMENTS,
+            output=MonitoringOutput(source="/opt/ml/processing/output", destination=OUTPUT_S3_URI),
+            batch_transform_input=batch_transform_input,
+        )
+    assert "Cannot update both batch_transform_input and endpoint_input to update an" in str(
+        error.value
+    )
+
+    with pytest.raises(ValueError) as error:
+        data_quality_monitor.update_monitoring_schedule(
+            schedule_cron_expression=CRON_HOURLY,
+            endpoint_input=endpoint_input,
+            batch_transform_input=batch_transform_input,
+        )
+    assert "Cannot update both batch_transform_input and endpoint_input to update an" in str(
+        error.value
+    )
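
One detail worth noting in the test above: endpoint_input and batch_transform_input are supplied as parameter defaults rather than as pytest fixtures. pytest resolves only parameters without default values as fixture requests, so defaults are a compact way to build shared inputs inline in the signature. A minimal, hypothetical illustration of the pattern (names here are invented for the sketch):

import pytest

@pytest.fixture
def session():
    return object()  # stand-in for a real fixture value

# `session` is resolved as a fixture; `threshold` keeps its default because
# pytest does not treat parameters with defaults as fixture requests.
def test_pattern(session, threshold=0.5):
    assert session is not None
    assert threshold == 0.5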
