
Commit 8ae7572

Author: Keshav Chandak
Commit message: Fix: updating batch transform job in monitoring schedule
Parent: c4c4e83

File tree: 3 files changed (+280, -0 lines)
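
This commit adds a batch_transform_input parameter to the update_monitoring_schedule methods in model_monitoring.py, so a schedule created against batch transform data capture can later be repointed at a new capture location. A minimal usage sketch of the intended call pattern (the role ARN, bucket, and monitor configuration below are illustrative assumptions, not values from this commit):

    from sagemaker.model_monitor import BatchTransformInput, DefaultModelMonitor
    from sagemaker.model_monitor.dataset_format import MonitoringDatasetFormat

    # Hypothetical monitor; in practice this is the monitor whose schedule
    # was created earlier with create_monitoring_schedule().
    monitor = DefaultModelMonitor(
        role="arn:aws:iam::123456789012:role/SageMakerRole",  # hypothetical role
        instance_count=1,
        instance_type="ml.m5.xlarge",
    )

    # Repoint the existing schedule at a new batch transform capture location.
    new_input = BatchTransformInput(
        data_captured_destination_s3_uri="s3://my-bucket/batch-capture",  # hypothetical
        destination="/opt/ml/processing/input",
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )
    monitor.update_monitoring_schedule(batch_transform_input=new_input)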

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 42 additions & 0 deletions
@@ -433,6 +433,7 @@ def update_monitoring_schedule(
         network_config=None,
         role=None,
         image_uri=None,
+        batch_transform_input=None,
     ):
         """Updates the existing monitoring schedule.

@@ -475,14 +476,29 @@ def update_monitoring_schedule(
             role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
             image_uri (str): The uri of the image to use for the jobs started by
                 the Monitor.
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
+                run the monitoring schedule on the batch transform (default: None)

         """
         monitoring_inputs = None
+
+        if (batch_transform_input is not None) and (endpoint_input is not None):
+            message = (
+                "Cannot update both batch_transform_input and endpoint_input to update an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide at most one of the above required inputs"
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         if endpoint_input is not None:
             monitoring_inputs = [
                 self._normalize_endpoint_input(endpoint_input=endpoint_input)._to_request_dict()
             ]

+        elif batch_transform_input is not None:
+            monitoring_inputs = [batch_transform_input._to_request_dict()]
+
         monitoring_output_config = None
         if output is not None:
             normalized_monitoring_output = self._normalize_monitoring_output_fields(output=output)
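
With the guard above, passing both input types on update now fails fast instead of silently preferring one. A sketch of the caller-visible behavior (monitor and new_input as in the earlier sketch; the endpoint name is hypothetical):

    from sagemaker.model_monitor import EndpointInput

    try:
        monitor.update_monitoring_schedule(
            endpoint_input=EndpointInput(
                endpoint_name="my-endpoint",  # hypothetical endpoint
                destination="/opt/ml/processing/input",
            ),
            batch_transform_input=new_input,
        )
    except ValueError as err:
        print(err)  # "Cannot update both batch_transform_input and endpoint_input ..."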
@@ -1895,6 +1911,7 @@ def update_monitoring_schedule(
         network_config=None,
         enable_cloudwatch_metrics=None,
         role=None,
+        batch_transform_input=None,
     ):
         """Updates the existing monitoring schedule.

@@ -1936,8 +1953,20 @@ def update_monitoring_schedule(
             enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
                 the baselining or monitoring jobs.
             role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
+            batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
+                run the monitoring schedule on the batch transform (default: None)

         """
+
+        if (batch_transform_input is not None) and (endpoint_input is not None):
+            message = (
+                "Cannot update both batch_transform_input and endpoint_input to update an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide at most one of the above required inputs"
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         # check if this schedule is in v2 format and update as per v2 format if it is
         if self.job_definition_name is not None:
             self._update_data_quality_monitoring_schedule(

@@ -1958,13 +1987,17 @@ def update_monitoring_schedule(
                 network_config=network_config,
                 enable_cloudwatch_metrics=enable_cloudwatch_metrics,
                 role=role,
+                batch_transform_input=batch_transform_input,
             )
             return

         monitoring_inputs = None
         if endpoint_input is not None:
             monitoring_inputs = [self._normalize_endpoint_input(endpoint_input)._to_request_dict()]

+        elif batch_transform_input is not None:
+            monitoring_inputs = [batch_transform_input._to_request_dict()]
+
         record_preprocessor_script_s3_uri = None
         if record_preprocessor_script is not None:
             record_preprocessor_script_s3_uri = self._s3_uri_from_local_path(

@@ -3022,6 +3055,15 @@ def update_monitoring_schedule(
             self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
             return

+        if (batch_transform_input is not None) and (endpoint_input is not None):
+            message = (
+                "Cannot update both batch_transform_input and endpoint_input to update an "
+                "Amazon Model Monitoring Schedule. "
+                "Please provide at most one of the above required inputs"
+            )
+            _LOGGER.error(message)
+            raise ValueError(message)
+
         # Need to update schedule with a new job definition
         job_desc = self.sagemaker_session.sagemaker_client.describe_model_quality_job_definition(
             JobDefinitionName=self.job_definition_name
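
For context, batch_transform_input._to_request_dict() is what populates the schedule's MonitoringInputs entry. Judging from the key path asserted in the tests below, the entry has roughly this shape (a sketch, not exhaustive; values echo the earlier example):

    {
        "BatchTransformInput": {
            "DataCapturedDestinationS3Uri": "s3://my-bucket/batch-capture",
            "LocalPath": "/opt/ml/processing/input",
            "DatasetFormat": {"Csv": {"Header": False}},
        }
    }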

tests/integ/test_model_monitor.py

Lines changed: 194 additions & 0 deletions
@@ -1852,3 +1852,197 @@ def _verify_default_monitoring_schedule_with_batch_transform(
         )
     else:
         assert network_config is None
+
+
+@pytest.mark.release
+def test_default_update_monitoring_batch_transform(
+    sagemaker_session, output_kms_key, volume_kms_key
+):
+    my_default_monitor = DefaultModelMonitor(
+        role=ROLE,
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE,
+        volume_size_in_gb=VOLUME_SIZE_IN_GB,
+        volume_kms_key=volume_kms_key,
+        output_kms_key=output_kms_key,
+        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
+        sagemaker_session=sagemaker_session,
+        env=ENVIRONMENT,
+        tags=TAGS,
+        network_config=NETWORK_CONFIG,
+    )
+
+    output_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "integ-test-monitoring-output-bucket",
+        str(uuid.uuid4()),
+    )
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    statistics = Statistics.from_file_path(
+        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    constraints = Constraints.from_file_path(
+        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    my_default_monitor.create_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+        output_s3_uri=output_s3_uri,
+        statistics=statistics,
+        constraints=constraints,
+        schedule_cron_expression=HOURLY_CRON_EXPRESSION,
+        enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_default_monitor)
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-tensorflow-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    my_default_monitor.update_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_default_monitor)
+
+    schedule_description = my_default_monitor.describe_schedule()
+
+    _verify_default_monitoring_schedule_with_batch_transform(
+        sagemaker_session=sagemaker_session,
+        schedule_description=schedule_description,
+        cron_expression=HOURLY_CRON_EXPRESSION,
+        statistics=statistics,
+        constraints=constraints,
+        output_kms_key=output_kms_key,
+        volume_kms_key=volume_kms_key,
+        network_config=NETWORK_CONFIG,
+    )
+
+    my_default_monitor.stop_monitoring_schedule()
+    my_default_monitor.delete_monitoring_schedule()
+
+
+@pytest.mark.release
+def test_byoc_monitoring_schedule_name_update_batch(
+    sagemaker_session, output_kms_key, volume_kms_key
+):
+    byoc_env = ENVIRONMENT.copy()
+    byoc_env["dataset_format"] = json.dumps(DatasetFormat.csv(header=False))
+    byoc_env["dataset_source"] = "/opt/ml/processing/input/baseline_dataset_input"
+    byoc_env["output_path"] = os.path.join("/opt/ml/processing/output")
+    byoc_env["publish_cloudwatch_metrics"] = "Disabled"
+
+    my_byoc_monitor = ModelMonitor(
+        role=ROLE,
+        image_uri=DefaultModelMonitor._get_default_image_uri(
+            sagemaker_session.boto_session.region_name
+        ),
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE,
+        volume_size_in_gb=VOLUME_SIZE_IN_GB,
+        volume_kms_key=volume_kms_key,
+        output_kms_key=output_kms_key,
+        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
+        sagemaker_session=sagemaker_session,
+        env=byoc_env,
+        tags=TAGS,
+        network_config=NETWORK_CONFIG,
+    )
+
+    output_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "integ-test-monitoring-output-bucket",
+        str(uuid.uuid4()),
+    )
+
+    statistics = Statistics.from_file_path(
+        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    constraints = Constraints.from_file_path(
+        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
+        sagemaker_session=sagemaker_session,
+    )
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    my_byoc_monitor.create_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+        output=MonitoringOutput(source="/opt/ml/processing/output", destination=output_s3_uri),
+        statistics=statistics,
+        constraints=constraints,
+        schedule_cron_expression=HOURLY_CRON_EXPRESSION,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)
+
+    data_captured_destination_s3_uri = os.path.join(
+        "s3://",
+        sagemaker_session.default_bucket(),
+        "sagemaker-tensorflow-serving-batch-transform",
+        str(uuid.uuid4()),
+    )
+
+    batch_transform_input = BatchTransformInput(
+        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
+        destination="/opt/ml/processing/output",
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    )
+
+    my_byoc_monitor.update_monitoring_schedule(
+        batch_transform_input=batch_transform_input,
+    )
+
+    _wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)
+
+    schedule_description = my_byoc_monitor.describe_schedule()
+
+    assert (
+        data_captured_destination_s3_uri
+        == schedule_description["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
+            "MonitoringInputs"
+        ][0]["BatchTransformInput"]["DataCapturedDestinationS3Uri"]
+    )
+
+    my_byoc_monitor.stop_monitoring_schedule()
+    my_byoc_monitor.delete_monitoring_schedule()
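
Outside the test suite, the same check the BYOC test makes can be run directly against a live schedule. A sketch, reusing the monitor name and capture URI from the first example:

    desc = monitor.describe_schedule()
    batch_input = desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
        "MonitoringInputs"
    ][0]["BatchTransformInput"]
    assert batch_input["DataCapturedDestinationS3Uri"] == "s3://my-bucket/batch-capture"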

tests/unit/sagemaker/monitor/test_model_monitoring.py

Lines changed: 44 additions & 0 deletions
@@ -1836,3 +1836,47 @@ def test_model_monitor_with_arguments(
         role_arn=ROLE,
         tags=TAGS,
     )
+
+
+def test_update_model_monitor_error_with_endpoint_and_batch(
+    model_monitor_arguments,
+    data_quality_monitor,
+    endpoint_input=EndpointInput(
+        endpoint_name=ENDPOINT_NAME,
+        destination=ENDPOINT_INPUT_LOCAL_PATH,
+        start_time_offset=START_TIME_OFFSET,
+        end_time_offset=END_TIME_OFFSET,
+        features_attribute=FEATURES_ATTRIBUTE,
+        inference_attribute=INFERENCE_ATTRIBUTE,
+        probability_attribute=PROBABILITY_ATTRIBUTE,
+        probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE,
+    ),
+    batch_transform_input=BatchTransformInput(
+        data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI,
+        destination=SCHEDULE_DESTINATION,
+        dataset_format=MonitoringDatasetFormat.csv(header=False),
+    ),
+):
+    try:
+        model_monitor_arguments.update_monitoring_schedule(
+            schedule_cron_expression=CRON_HOURLY,
+            endpoint_input=endpoint_input,
+            arguments=ARGUMENTS,
+            output=MonitoringOutput(source="/opt/ml/processing/output", destination=OUTPUT_S3_URI),
+            batch_transform_input=batch_transform_input,
+        )
+    except ValueError as error:
+        assert "Cannot update both batch_transform_input and endpoint_input to update an" in str(
+            error
+        )
+
+    try:
+        data_quality_monitor.update_monitoring_schedule(
+            schedule_cron_expression=CRON_HOURLY,
+            endpoint_input=endpoint_input,
+            batch_transform_input=batch_transform_input,
+        )
+    except ValueError as error:
+        assert "Cannot update both batch_transform_input and endpoint_input to update an" in str(
+            error
+        )
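
One caveat about the committed test: because the assertion sits inside an except block, the test still passes if update_monitoring_schedule raises nothing. A stricter variant (a sketch, not part of this commit; ENDPOINT_INPUT and BATCH_INPUT stand in for the objects built above) would let pytest enforce that the ValueError actually fires:

    import pytest

    def test_update_rejects_both_inputs(
        model_monitor_arguments,
        endpoint_input=ENDPOINT_INPUT,      # hypothetical module-level EndpointInput
        batch_transform_input=BATCH_INPUT,  # hypothetical module-level BatchTransformInput
    ):
        # pytest.raises fails the test unless a matching ValueError is raised.
        with pytest.raises(ValueError, match="Cannot update both batch_transform_input"):
            model_monitor_arguments.update_monitoring_schedule(
                schedule_cron_expression=CRON_HOURLY,
                endpoint_input=endpoint_input,
                batch_transform_input=batch_transform_input,
            )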
