Skip to content

Commit d986e3e

Browse files
keshav-chandak (Keshav Chandak)
and
Keshav Chandak
authored
fix: updating batch transform job in monitoring schedule (#3767)
Co-authored-by: Keshav Chandak <[email protected]>
1 parent 60cbb64 commit d986e3e

File tree

3 files changed

+278
-2
lines changed

3 files changed

+278
-2
lines changed

src/sagemaker/model_monitor/model_monitoring.py

+42
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ def update_monitoring_schedule(
433433
network_config=None,
434434
role=None,
435435
image_uri=None,
436+
batch_transform_input=None,
436437
):
437438
"""Updates the existing monitoring schedule.
438439
@@ -475,14 +476,29 @@ def update_monitoring_schedule(
475476
role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
476477
image_uri (str): The uri of the image to use for the jobs started by
477478
the Monitor.
479+
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
480+
run the monitoring schedule on the batch transform (default: None)
478481
479482
"""
480483
monitoring_inputs = None
484+
485+
if (batch_transform_input is not None) and (endpoint_input is not None):
486+
message = (
487+
"Cannot update both batch_transform_input and endpoint_input to update an "
488+
"Amazon Model Monitoring Schedule. "
489+
"Please provide atmost one of the above required inputs"
490+
)
491+
_LOGGER.error(message)
492+
raise ValueError(message)
493+
481494
if endpoint_input is not None:
482495
monitoring_inputs = [
483496
self._normalize_endpoint_input(endpoint_input=endpoint_input)._to_request_dict()
484497
]
485498

499+
elif batch_transform_input is not None:
500+
monitoring_inputs = [batch_transform_input._to_request_dict()]
501+
486502
monitoring_output_config = None
487503
if output is not None:
488504
normalized_monitoring_output = self._normalize_monitoring_output_fields(output=output)
@@ -1895,6 +1911,7 @@ def update_monitoring_schedule(
18951911
network_config=None,
18961912
enable_cloudwatch_metrics=None,
18971913
role=None,
1914+
batch_transform_input=None,
18981915
):
18991916
"""Updates the existing monitoring schedule.
19001917
@@ -1936,8 +1953,20 @@ def update_monitoring_schedule(
19361953
enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
19371954
the baselining or monitoring jobs.
19381955
role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role.
1956+
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
1957+
run the monitoring schedule on the batch transform (default: None)
19391958
19401959
"""
1960+
1961+
if (batch_transform_input is not None) and (endpoint_input is not None):
1962+
message = (
1963+
"Cannot update both batch_transform_input and endpoint_input to update an "
1964+
"Amazon Model Monitoring Schedule. "
1965+
"Please provide atmost one of the above required inputs"
1966+
)
1967+
_LOGGER.error(message)
1968+
raise ValueError(message)
1969+
19411970
# check if this schedule is in v2 format and update as per v2 format if it is
19421971
if self.job_definition_name is not None:
19431972
self._update_data_quality_monitoring_schedule(
@@ -1958,13 +1987,17 @@ def update_monitoring_schedule(
19581987
network_config=network_config,
19591988
enable_cloudwatch_metrics=enable_cloudwatch_metrics,
19601989
role=role,
1990+
batch_transform_input=batch_transform_input,
19611991
)
19621992
return
19631993

19641994
monitoring_inputs = None
19651995
if endpoint_input is not None:
19661996
monitoring_inputs = [self._normalize_endpoint_input(endpoint_input)._to_request_dict()]
19671997

1998+
elif batch_transform_input is not None:
1999+
monitoring_inputs = [batch_transform_input._to_request_dict()]
2000+
19682001
record_preprocessor_script_s3_uri = None
19692002
if record_preprocessor_script is not None:
19702003
record_preprocessor_script_s3_uri = self._s3_uri_from_local_path(
@@ -3022,6 +3055,15 @@ def update_monitoring_schedule(
30223055
self._update_monitoring_schedule(self.job_definition_name, schedule_cron_expression)
30233056
return
30243057

3058+
if (batch_transform_input is not None) and (endpoint_input is not None):
3059+
message = (
3060+
"Cannot update both batch_transform_input and endpoint_input to update an "
3061+
"Amazon Model Monitoring Schedule. "
3062+
"Please provide atmost one of the above required inputs"
3063+
)
3064+
_LOGGER.error(message)
3065+
raise ValueError(message)
3066+
30253067
# Need to update schedule with a new job definition
30263068
job_desc = self.sagemaker_session.sagemaker_client.describe_model_quality_job_definition(
30273069
JobDefinitionName=self.job_definition_name

tests/integ/test_model_monitor.py

+192-2
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ def updated_output_kms_key(sagemaker_session):
281281
tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS,
282282
reason="ModelMonitoring is not yet supported in this region.",
283283
)
284-
@pytest.mark.release
285284
def test_default_monitoring_batch_transform_schedule_name(
286285
sagemaker_session, output_kms_key, volume_kms_key
287286
):
@@ -359,7 +358,6 @@ def test_default_monitoring_batch_transform_schedule_name(
359358
tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS,
360359
reason="ModelMonitoring is not yet supported in this region.",
361360
)
362-
@pytest.mark.release
363361
def test_default_monitor_suggest_baseline_and_create_monitoring_schedule_with_customizations(
364362
sagemaker_session, output_kms_key, volume_kms_key, predictor
365363
):
@@ -1852,3 +1850,195 @@ def _verify_default_monitoring_schedule_with_batch_transform(
18521850
)
18531851
else:
18541852
assert network_config is None
1853+
1854+
1855+
def test_default_update_monitoring_batch_transform(
    sagemaker_session, output_kms_key, volume_kms_key
):
    """Create a default monitoring schedule on a batch transform input, then
    update it to a new data-capture location and verify the schedule.

    Args:
        sagemaker_session: pytest fixture providing a SageMaker session.
        output_kms_key: pytest fixture providing the output KMS key.
        volume_kms_key: pytest fixture providing the volume KMS key.
    """
    my_default_monitor = DefaultModelMonitor(
        role=ROLE,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        volume_size_in_gb=VOLUME_SIZE_IN_GB,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
        sagemaker_session=sagemaker_session,
        env=ENVIRONMENT,
        tags=TAGS,
        network_config=NETWORK_CONFIG,
    )

    output_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "integ-test-monitoring-output-bucket",
        str(uuid.uuid4()),
    )

    data_captured_destination_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "sagemaker-serving-batch-transform",
        str(uuid.uuid4()),
    )

    batch_transform_input = BatchTransformInput(
        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
        destination="/opt/ml/processing/output",
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )

    statistics = Statistics.from_file_path(
        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
        sagemaker_session=sagemaker_session,
    )

    constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
        sagemaker_session=sagemaker_session,
    )

    my_default_monitor.create_monitoring_schedule(
        batch_transform_input=batch_transform_input,
        output_s3_uri=output_s3_uri,
        statistics=statistics,
        constraints=constraints,
        schedule_cron_expression=HOURLY_CRON_EXPRESSION,
        enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS,
    )

    # From here on the schedule exists in AWS: make sure it is stopped and
    # deleted even when the update or a verification assertion fails,
    # otherwise a failing test leaks a live monitoring schedule.
    try:
        _wait_for_schedule_changes_to_apply(monitor=my_default_monitor)

        # Point the schedule at a fresh capture location to exercise update.
        data_captured_destination_s3_uri = os.path.join(
            "s3://",
            sagemaker_session.default_bucket(),
            "sagemaker-tensorflow-serving-batch-transform",
            str(uuid.uuid4()),
        )

        batch_transform_input = BatchTransformInput(
            data_captured_destination_s3_uri=data_captured_destination_s3_uri,
            destination="/opt/ml/processing/output",
            dataset_format=MonitoringDatasetFormat.csv(header=False),
        )

        my_default_monitor.update_monitoring_schedule(
            batch_transform_input=batch_transform_input,
        )

        _wait_for_schedule_changes_to_apply(monitor=my_default_monitor)

        schedule_description = my_default_monitor.describe_schedule()

        _verify_default_monitoring_schedule_with_batch_transform(
            sagemaker_session=sagemaker_session,
            schedule_description=schedule_description,
            cron_expression=HOURLY_CRON_EXPRESSION,
            statistics=statistics,
            constraints=constraints,
            output_kms_key=output_kms_key,
            volume_kms_key=volume_kms_key,
            network_config=NETWORK_CONFIG,
        )
    finally:
        my_default_monitor.stop_monitoring_schedule()
        my_default_monitor.delete_monitoring_schedule()
1947+
1948+
1949+
def test_byoc_monitoring_schedule_name_update_batch(
    sagemaker_session, output_kms_key, volume_kms_key
):
    """Create a BYOC monitoring schedule on a batch transform input, update it
    to a new data-capture location, and verify the described schedule.

    Args:
        sagemaker_session: pytest fixture providing a SageMaker session.
        output_kms_key: pytest fixture providing the output KMS key.
        volume_kms_key: pytest fixture providing the volume KMS key.
    """
    byoc_env = ENVIRONMENT.copy()
    byoc_env["dataset_format"] = json.dumps(DatasetFormat.csv(header=False))
    byoc_env["dataset_source"] = "/opt/ml/processing/input/baseline_dataset_input"
    byoc_env["output_path"] = os.path.join("/opt/ml/processing/output")
    byoc_env["publish_cloudwatch_metrics"] = "Disabled"

    my_byoc_monitor = ModelMonitor(
        role=ROLE,
        image_uri=DefaultModelMonitor._get_default_image_uri(
            sagemaker_session.boto_session.region_name
        ),
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        volume_size_in_gb=VOLUME_SIZE_IN_GB,
        volume_kms_key=volume_kms_key,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS,
        sagemaker_session=sagemaker_session,
        env=byoc_env,
        tags=TAGS,
        network_config=NETWORK_CONFIG,
    )

    output_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "integ-test-monitoring-output-bucket",
        str(uuid.uuid4()),
    )

    statistics = Statistics.from_file_path(
        statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"),
        sagemaker_session=sagemaker_session,
    )

    constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"),
        sagemaker_session=sagemaker_session,
    )

    data_captured_destination_s3_uri = os.path.join(
        "s3://",
        sagemaker_session.default_bucket(),
        "sagemaker-serving-batch-transform",
        str(uuid.uuid4()),
    )

    batch_transform_input = BatchTransformInput(
        data_captured_destination_s3_uri=data_captured_destination_s3_uri,
        destination="/opt/ml/processing/output",
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    )

    # BUG FIX: the original passed the BatchTransformInput through
    # `endpoint_input=`, which hands a batch input to the endpoint-input code
    # path; it must go through `batch_transform_input=`.
    my_byoc_monitor.create_monitoring_schedule(
        batch_transform_input=batch_transform_input,
        output=MonitoringOutput(source="/opt/ml/processing/output", destination=output_s3_uri),
        statistics=statistics,
        constraints=constraints,
        schedule_cron_expression=HOURLY_CRON_EXPRESSION,
    )

    # From here on the schedule exists in AWS: stop and delete it even when
    # the update or an assertion fails, so a failing test does not leak a
    # live monitoring schedule.
    try:
        _wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)

        # Point the schedule at a fresh capture location to exercise update.
        data_captured_destination_s3_uri = os.path.join(
            "s3://",
            sagemaker_session.default_bucket(),
            "sagemaker-tensorflow-serving-batch-transform",
            str(uuid.uuid4()),
        )

        batch_transform_input = BatchTransformInput(
            data_captured_destination_s3_uri=data_captured_destination_s3_uri,
            destination="/opt/ml/processing/output",
            dataset_format=MonitoringDatasetFormat.csv(header=False),
        )

        my_byoc_monitor.update_monitoring_schedule(
            batch_transform_input=batch_transform_input,
        )

        _wait_for_schedule_changes_to_apply(monitor=my_byoc_monitor)

        schedule_description = my_byoc_monitor.describe_schedule()

        assert (
            data_captured_destination_s3_uri
            == schedule_description["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
                "MonitoringInputs"
            ][0]["BatchTransformInput"]["DataCapturedDestinationS3Uri"]
        )
    finally:
        my_byoc_monitor.stop_monitoring_schedule()
        my_byoc_monitor.delete_monitoring_schedule()

tests/unit/sagemaker/monitor/test_model_monitoring.py

+44
Original file line numberDiff line numberDiff line change
@@ -1836,3 +1836,47 @@ def test_model_monitor_with_arguments(
18361836
role_arn=ROLE,
18371837
tags=TAGS,
18381838
)
1839+
1840+
1841+
def test_update_model_monitor_error_with_endpoint_and_batch(
    model_monitor_arguments,
    data_quality_monitor,
    endpoint_input=EndpointInput(
        endpoint_name=ENDPOINT_NAME,
        destination=ENDPOINT_INPUT_LOCAL_PATH,
        start_time_offset=START_TIME_OFFSET,
        end_time_offset=END_TIME_OFFSET,
        features_attribute=FEATURES_ATTRIBUTE,
        inference_attribute=INFERENCE_ATTRIBUTE,
        probability_attribute=PROBABILITY_ATTRIBUTE,
        probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE,
    ),
    batch_transform_input=BatchTransformInput(
        data_captured_destination_s3_uri=DATA_CAPTURED_S3_URI,
        destination=SCHEDULE_DESTINATION,
        dataset_format=MonitoringDatasetFormat.csv(header=False),
    ),
):
    """update_monitoring_schedule must raise ValueError when both an endpoint
    input and a batch transform input are supplied.

    Args:
        model_monitor_arguments: pytest fixture providing a ModelMonitor.
        data_quality_monitor: pytest fixture providing a DefaultModelMonitor.
        endpoint_input: prebuilt EndpointInput used for both calls.
        batch_transform_input: prebuilt BatchTransformInput used for both calls.
    """
    expected_message = (
        "Cannot update both batch_transform_input and endpoint_input to update an"
    )

    # BUG FIX: the original wrapped each call in try/except and asserted
    # inside the `except` block, so the test silently PASSED when no
    # ValueError was raised. pytest.raises makes a missing exception a
    # test failure.
    with pytest.raises(ValueError) as error:
        model_monitor_arguments.update_monitoring_schedule(
            schedule_cron_expression=CRON_HOURLY,
            endpoint_input=endpoint_input,
            arguments=ARGUMENTS,
            output=MonitoringOutput(source="/opt/ml/processing/output", destination=OUTPUT_S3_URI),
            batch_transform_input=batch_transform_input,
        )
    assert expected_message in str(error.value)

    with pytest.raises(ValueError) as error:
        data_quality_monitor.update_monitoring_schedule(
            schedule_cron_expression=CRON_HOURLY,
            endpoint_input=endpoint_input,
            batch_transform_input=batch_transform_input,
        )
    assert expected_message in str(error.value)

0 commit comments

Comments
 (0)